diff --git a/.aztec-sync-commit b/.aztec-sync-commit index 063664dceac..477ebbca903 100644 --- a/.aztec-sync-commit +++ b/.aztec-sync-commit @@ -1 +1 @@ -b8bace9a00c3a8eb93f42682e8cbfa351fc5238c +0577c1a70e9746bd06f07d2813af1be39e01ca02 diff --git a/.github/ACVM_NOT_PUBLISHABLE.md b/.github/ACVM_NOT_PUBLISHABLE.md index 33230f8e8d8..06c9505ebae 100644 --- a/.github/ACVM_NOT_PUBLISHABLE.md +++ b/.github/ACVM_NOT_PUBLISHABLE.md @@ -5,7 +5,7 @@ assignees: TomAFrench, Savio-Sou The ACVM crates are currently unpublishable, making a release will NOT push our crates to crates.io. -This is likely due to a crate we depend on bumping its MSRV above our own. Our lockfile is not taken into account when publishing to crates.io (as people downloading our crate don't use it) so we need to be able to use the most up to date versions of our dependencies (including transient dependencies) specified. +This is likely due to a crate we depend on bumping its MSRV above our own. Our lockfile is not taken into account when publishing to crates.io (as people downloading our crate don't use it) so we need to be able to use the most up-to-date versions of our dependencies (including transient dependencies) specified. Check the [MSRV check]({{env.WORKFLOW_URL}}) workflow for details. diff --git a/.github/workflows/memory_report.yml b/.github/workflows/memory_report.yml new file mode 100644 index 00000000000..c31c750e6fc --- /dev/null +++ b/.github/workflows/memory_report.yml @@ -0,0 +1,88 @@ +name: Report Peak Memory + +on: + push: + branches: + - master + pull_request: + +jobs: + build-nargo: + runs-on: ubuntu-latest + strategy: + matrix: + target: [x86_64-unknown-linux-gnu] + + steps: + - name: Checkout Noir repo + uses: actions/checkout@v4 + + - name: Setup toolchain + uses: dtolnay/rust-toolchain@1.74.1 + + - uses: Swatinem/rust-cache@v2 + with: + key: ${{ matrix.target }} + cache-on-failure: true + save-if: ${{ github.event_name != 'merge_group' }} + + - name: Build Nargo + run: cargo build --package nargo_cli --release + + - name: Package artifacts + run: | + mkdir dist + cp ./target/release/nargo ./dist/nargo + + - name: Upload artifact + uses: actions/upload-artifact@v4 + with: + name: nargo + path: ./dist/* + retention-days: 3 + + generate_memory_report: + needs: [build-nargo] + runs-on: ubuntu-latest + permissions: + pull-requests: write + + steps: + - uses: actions/checkout@v4 + + - name: Download nargo binary + uses: actions/download-artifact@v4 + with: + name: nargo + path: ./nargo + + - name: Set nargo on PATH + run: | + nargo_binary="${{ github.workspace }}/nargo/nargo" + chmod +x $nargo_binary + echo "$(dirname $nargo_binary)" >> $GITHUB_PATH + export PATH="$PATH:$(dirname $nargo_binary)" + nargo -V + + - name: Generate Memory report + working-directory: ./test_programs + run: | + chmod +x memory_report.sh + ./memory_report.sh + mv memory_report.json ../memory_report.json + + - name: Parse memory report + id: memory_report + uses: noir-lang/noir-bench-report@ccb0d806a91d3bd86dba0ba3d580a814eed5673c + with: + report: memory_report.json + header: | + # Memory Report + memory_report: true + + - name: Add memory report to sticky comment + if: github.event_name == 'pull_request' || github.event_name == 'pull_request_target' + uses: marocchino/sticky-pull-request-comment@v2 + with: + header: memory + message: ${{ steps.memory_report.outputs.markdown }} \ No newline at end of file diff --git a/.github/workflows/test-js-packages.yml b/.github/workflows/test-js-packages.yml index 152d8b1653e..422a30ed08f 
100644 --- a/.github/workflows/test-js-packages.yml +++ b/.github/workflows/test-js-packages.yml @@ -478,6 +478,9 @@ jobs: - name: Install Foundry uses: foundry-rs/foundry-toolchain@v1.2.0 + with: + version: nightly-8660e5b941fe7f4d67e246cfd3dafea330fb53b1 + - name: Install `bb` run: | @@ -516,12 +519,25 @@ jobs: fail-fast: false matrix: project: - # Disabled as these are currently failing with many visibility errors - - { repo: AztecProtocol/aztec-nr, path: ./ } + - { repo: noir-lang/ec, path: ./ } + - { repo: noir-lang/eddsa, path: ./ } + - { repo: noir-lang/mimc, path: ./ } + - { repo: noir-lang/noir_sort, path: ./ } + - { repo: noir-lang/noir-edwards, path: ./ } + - { repo: noir-lang/noir-bignum, path: ./ } + - { repo: noir-lang/noir_bigcurve, path: ./ } + - { repo: noir-lang/noir_base64, path: ./ } + - { repo: noir-lang/noir_string_search, path: ./ } + - { repo: noir-lang/sparse_array, path: ./ } + - { repo: noir-lang/noir_rsa, path: ./lib } + - { repo: AztecProtocol/aztec-packages, path: ./noir-projects/aztec-nr } - { repo: AztecProtocol/aztec-packages, path: ./noir-projects/noir-contracts } - # Disabled as aztec-packages requires a setup-step in order to generate a `Nargo.toml` - #- { repo: AztecProtocol/aztec-packages, path: ./noir-projects/noir-protocol-circuits } - - { repo: noir-lang/noir-edwards, path: ./, ref: 3188ea74fe3b059219a2ea87899589c266256d74 } + - { repo: AztecProtocol/aztec-packages, path: ./noir-projects/noir-protocol-circuits/crates/parity-lib } + - { repo: AztecProtocol/aztec-packages, path: ./noir-projects/noir-protocol-circuits/crates/private-kernel-lib } + - { repo: AztecProtocol/aztec-packages, path: ./noir-projects/noir-protocol-circuits/crates/reset-kernel-lib } + - { repo: AztecProtocol/aztec-packages, path: ./noir-projects/noir-protocol-circuits/crates/rollup-lib } + - { repo: AztecProtocol/aztec-packages, path: ./noir-projects/noir-protocol-circuits/crates/types } + name: Check external repo - ${{ matrix.project.repo }} steps: - name: Checkout @@ -551,9 +567,12 @@ jobs: # Github actions seems to not expand "**" in globs by default. shopt -s globstar sed -i '/^compiler_version/d' ./**/Nargo.toml - - name: Run nargo check + + - name: Run nargo test working-directory: ./test-repo/${{ matrix.project.path }} - run: nargo check + run: nargo test --silence-warnings + env: + NARGO_IGNORE_TEST_FAILURES_FROM_FOREIGN_CALLS: true # This is a job which depends on all test jobs and reports the overall status. # This allows us to add/remove test jobs without having to update the required workflows. 
diff --git a/Cargo.lock b/Cargo.lock index aacd8f7e596..2f19ed704b2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -805,9 +805,9 @@ dependencies = [ [[package]] name = "clap_complete" -version = "4.5.38" +version = "4.5.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9647a559c112175f17cf724dc72d3645680a883c58481332779192b0d8e7a01" +checksum = "11611dca53440593f38e6b25ec629de50b14cdfa63adc0fb856115a2c6d97595" dependencies = [ "clap", ] @@ -2789,6 +2789,7 @@ dependencies = [ "dirs", "file-lock", "fm", + "fxhash", "iai", "iter-extended", "lazy_static", @@ -3151,6 +3152,7 @@ dependencies = [ "serde_json", "serde_with", "similar-asserts", + "test-case", "thiserror", "tracing", ] diff --git a/Cargo.toml b/Cargo.toml index d6a56b7990d..94ebe54fde1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -157,6 +157,8 @@ proptest-derive = "0.4.0" rayon = "1.8.0" sha2 = { version = "0.10.6", features = ["compress"] } sha3 = "0.10.6" +strum = "0.24" +strum_macros = "0.24" im = { version = "15.1", features = ["serde"] } tracing = "0.1.40" diff --git a/acvm-repo/acir/Cargo.toml b/acvm-repo/acir/Cargo.toml index 1f2ad1fb787..8139a58eefc 100644 --- a/acvm-repo/acir/Cargo.toml +++ b/acvm-repo/acir/Cargo.toml @@ -24,11 +24,11 @@ flate2.workspace = true bincode.workspace = true base64.workspace = true serde-big-array = "0.5.1" +strum = { workspace = true } +strum_macros = { workspace = true } [dev-dependencies] serde_json = "1.0" -strum = "0.24" -strum_macros = "0.24" serde-reflection = "0.3.6" serde-generate = "0.25.1" fxhash.workspace = true diff --git a/acvm-repo/acir/src/circuit/black_box_functions.rs b/acvm-repo/acir/src/circuit/black_box_functions.rs index 2e5a94f1c50..25842c14dbc 100644 --- a/acvm-repo/acir/src/circuit/black_box_functions.rs +++ b/acvm-repo/acir/src/circuit/black_box_functions.rs @@ -4,12 +4,10 @@ //! implemented in more basic constraints. use serde::{Deserialize, Serialize}; -#[cfg(test)] use strum_macros::EnumIter; #[allow(clippy::upper_case_acronyms)] -#[derive(Clone, Debug, Hash, Copy, PartialEq, Eq, Serialize, Deserialize)] -#[cfg_attr(test, derive(EnumIter))] +#[derive(Clone, Debug, Hash, Copy, PartialEq, Eq, Serialize, Deserialize, EnumIter)] pub enum BlackBoxFunc { /// Ciphers (encrypts) the provided plaintext using AES128 in CBC mode, /// padding the input using PKCS#7. diff --git a/acvm-repo/acir/src/circuit/brillig.rs b/acvm-repo/acir/src/circuit/brillig.rs index a9714ce29b2..ef75d088f8c 100644 --- a/acvm-repo/acir/src/circuit/brillig.rs +++ b/acvm-repo/acir/src/circuit/brillig.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; /// Inputs for the Brillig VM. These are the initial inputs /// that the Brillig VM will use to start. -#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug)] +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash)] pub enum BrilligInputs { Single(Expression), Array(Vec>), @@ -14,7 +14,7 @@ pub enum BrilligInputs { /// Outputs for the Brillig VM. Once the VM has completed /// execution, this will be the object that is returned. -#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug)] +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash)] pub enum BrilligOutputs { Simple(Witness), Array(Vec), @@ -23,7 +23,7 @@ pub enum BrilligOutputs { /// This is purely a wrapper struct around a list of Brillig opcode's which represents /// a full Brillig function to be executed by the Brillig VM. /// This is stored separately on a program and accessed through a [BrilligPointer]. 
-#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default, Debug)] +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default, Debug, Hash)] pub struct BrilligBytecode { pub bytecode: Vec>, } diff --git a/acvm-repo/acir/src/circuit/mod.rs b/acvm-repo/acir/src/circuit/mod.rs index 33982065c2a..88605d3bdab 100644 --- a/acvm-repo/acir/src/circuit/mod.rs +++ b/acvm-repo/acir/src/circuit/mod.rs @@ -25,7 +25,7 @@ use self::{brillig::BrilligBytecode, opcodes::BlockId}; /// Bounded Expressions are useful if you are eventually going to pass the ACIR /// into a proving system which supports PLONK, where arithmetic expressions have a /// finite fan-in. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, Hash)] pub enum ExpressionWidth { #[default] Unbounded, @@ -36,13 +36,13 @@ pub enum ExpressionWidth { /// A program represented by multiple ACIR circuits. The execution trace of these /// circuits is dictated by construction of the [crate::native_types::WitnessStack]. -#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default)] +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default, Hash)] pub struct Program { pub functions: Vec>, pub unconstrained_functions: Vec>, } -#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default)] +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default, Hash)] pub struct Circuit { // current_witness_index is the highest witness index in the circuit. The next witness to be added to this circuit // will take on this value. (The value is cached here as an optimization.) @@ -69,13 +69,13 @@ pub struct Circuit { pub assert_messages: Vec<(OpcodeLocation, AssertionPayload)>, } -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub enum ExpressionOrMemory { Expression(Expression), Memory(BlockId), } -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub struct AssertionPayload { pub error_selector: u64, pub payload: Vec>, @@ -355,7 +355,7 @@ impl std::fmt::Debug for Program { } } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default, Hash)] pub struct PublicInputs(pub BTreeSet); impl PublicInputs { diff --git a/acvm-repo/acir/src/circuit/opcodes.rs b/acvm-repo/acir/src/circuit/opcodes.rs index 06effd3c5b6..f47c40b0dd7 100644 --- a/acvm-repo/acir/src/circuit/opcodes.rs +++ b/acvm-repo/acir/src/circuit/opcodes.rs @@ -15,7 +15,7 @@ pub use black_box_function_call::{ }; pub use memory_operation::{BlockId, MemOp}; -#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub enum BlockType { Memory, CallData(u32), @@ -29,7 +29,7 @@ impl BlockType { } #[allow(clippy::large_enum_variant)] -#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub enum Opcode { /// An `AssertZero` opcode adds the constraint that `P(w) = 0`, where /// `w=(w_1,..w_n)` is a tuple of `n` witnesses, and `P` is a multi-variate diff --git a/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs b/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs index fa51caf5155..e756eedefbc 100644 --- a/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs +++ 
b/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs @@ -9,13 +9,13 @@ use thiserror::Error; // Note: Some functions will not use all of the witness // So we need to supply how many bits of the witness is needed -#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] pub enum ConstantOrWitnessEnum { Constant(F), Witness(Witness), } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] pub struct FunctionInput { input: ConstantOrWitnessEnum, num_bits: u32, @@ -79,7 +79,7 @@ impl std::fmt::Display for FunctionInput { } } -#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub enum BlackBoxFuncCall { AES128Encrypt { inputs: Vec>, diff --git a/acvm-repo/acir/src/circuit/opcodes/memory_operation.rs b/acvm-repo/acir/src/circuit/opcodes/memory_operation.rs index 90e3ee0563a..c9a78983204 100644 --- a/acvm-repo/acir/src/circuit/opcodes/memory_operation.rs +++ b/acvm-repo/acir/src/circuit/opcodes/memory_operation.rs @@ -7,7 +7,7 @@ pub struct BlockId(pub u32); /// Operation on a block of memory /// We can either write or read at an index in memory -#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug)] +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash)] pub struct MemOp { /// A constant expression that can be 0 (read) or 1 (write) pub operation: Expression, diff --git a/acvm-repo/acir_field/src/field_element.rs b/acvm-repo/acir_field/src/field_element.rs index 47ceb903111..0249b410aa7 100644 --- a/acvm-repo/acir_field/src/field_element.rs +++ b/acvm-repo/acir_field/src/field_element.rs @@ -9,7 +9,7 @@ use crate::AcirField; // XXX: Switch out for a trait and proper implementations // This implementation is inefficient, can definitely remove hex usage and Iterator instances for trivial functionality -#[derive(Default, Clone, Copy, Eq, PartialOrd, Ord)] +#[derive(Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct FieldElement(F); impl std::fmt::Display for FieldElement { @@ -43,18 +43,6 @@ impl std::fmt::Debug for FieldElement { } } -impl std::hash::Hash for FieldElement { - fn hash(&self, state: &mut H) { - state.write(&self.to_be_bytes()); - } -} - -impl PartialEq for FieldElement { - fn eq(&self, other: &Self) -> bool { - self.to_be_bytes() == other.to_be_bytes() - } -} - impl From for FieldElement { fn from(mut a: i128) -> FieldElement { let mut negative = false; @@ -158,23 +146,6 @@ impl FieldElement { let fr = F::from_str(input).ok()?; Some(FieldElement(fr)) } - - fn bits(&self) -> Vec { - fn byte_to_bit(byte: u8) -> Vec { - let mut bits = Vec::with_capacity(8); - for index in (0..=7).rev() { - bits.push((byte & (1 << index)) >> index == 1); - } - bits - } - - let bytes = self.to_be_bytes(); - let mut bits = Vec::with_capacity(bytes.len() * 8); - for byte in bytes { - bits.extend(byte_to_bit(byte)); - } - bits - } } impl AcirField for FieldElement { @@ -224,12 +195,26 @@ impl AcirField for FieldElement { /// This is the number of bits required to represent this specific field element fn num_bits(&self) -> u32 { - let bits = self.bits(); - // Iterate the number of bits and pop off all leading zeroes - let iter = bits.iter().skip_while(|x| !(**x)); + let bytes = self.to_be_bytes(); + + // Iterate through the byte decomposition and pop off all leading zeroes + let mut iter = 
bytes.iter().skip_while(|x| (**x) == 0); + + // The first non-zero byte in the decomposition may have some leading zero-bits. + let Some(head_byte) = iter.next() else { + // If we don't have a non-zero byte then the field element is zero, + // which we consider to require a single bit to represent. + return 1; + }; + let num_bits_for_head_byte = head_byte.ilog2(); + + // Each remaining byte in the byte decomposition requires 8 bits. + // // Note: count will panic if it goes over usize::MAX. // This may not be suitable for devices whose usize < u16 - iter.count() as u32 + let tail_length = iter.count() as u32; + + 8 * tail_length + num_bits_for_head_byte + 1 } fn to_u128(self) -> u128 { @@ -374,6 +359,30 @@ mod tests { use super::{AcirField, FieldElement}; use proptest::prelude::*; + #[test] + fn requires_one_bit_to_hold_zero() { + let field = FieldElement::::zero(); + assert_eq!(field.num_bits(), 1); + } + + proptest! { + #[test] + fn num_bits_agrees_with_ilog2(num in 1u128..) { + let field = FieldElement::::from(num); + prop_assert_eq!(field.num_bits(), num.ilog2() + 1); + } + } + + #[test] + fn test_fits_in_u128() { + let field = FieldElement::::from(u128::MAX); + assert_eq!(field.num_bits(), 128); + assert!(field.fits_in_u128()); + let big_field = field + FieldElement::one(); + assert_eq!(big_field.num_bits(), 129); + assert!(!big_field.fits_in_u128()); + } + #[test] fn serialize_fixed_test_vectors() { // Serialized field elements from of 0, -1, -2, -3 diff --git a/acvm-repo/acvm/Cargo.toml b/acvm-repo/acvm/Cargo.toml index 6ebfed27a32..e513ae4e727 100644 --- a/acvm-repo/acvm/Cargo.toml +++ b/acvm-repo/acvm/Cargo.toml @@ -25,11 +25,7 @@ acvm_blackbox_solver.workspace = true indexmap = "1.7.0" [features] -bn254 = [ - "acir/bn254", - "brillig_vm/bn254", - "acvm_blackbox_solver/bn254", -] +bn254 = ["acir/bn254", "brillig_vm/bn254", "acvm_blackbox_solver/bn254"] bls12_381 = [ "acir/bls12_381", "brillig_vm/bls12_381", @@ -37,10 +33,11 @@ bls12_381 = [ ] [dev-dependencies] -ark-bls12-381 = { version = "^0.4.0", default-features = false, features = ["curve"] } +ark-bls12-381 = { version = "^0.4.0", default-features = false, features = [ + "curve", +] } ark-bn254.workspace = true bn254_blackbox_solver.workspace = true proptest.workspace = true zkhash = { version = "^0.2.0", default-features = false } num-bigint.workspace = true - diff --git a/acvm-repo/blackbox_solver/src/bigint.rs b/acvm-repo/blackbox_solver/src/bigint.rs index b8bc9dc0d70..540862843ab 100644 --- a/acvm-repo/blackbox_solver/src/bigint.rs +++ b/acvm-repo/blackbox_solver/src/bigint.rs @@ -97,3 +97,51 @@ impl BigIntSolver { Ok(()) } } + +/// Wrapper over the generic bigint solver to automatically assign bigint IDs. 
+#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub struct BigIntSolverWithId { + solver: BigIntSolver, + last_id: u32, +} + +impl BigIntSolverWithId { + pub fn create_bigint_id(&mut self) -> u32 { + let output = self.last_id; + self.last_id += 1; + output + } + + pub fn bigint_from_bytes( + &mut self, + inputs: &[u8], + modulus: &[u8], + ) -> Result<u32, BlackBoxResolutionError> { + let id = self.create_bigint_id(); + self.solver.bigint_from_bytes(inputs, modulus, id)?; + Ok(id) + } + + pub fn bigint_to_bytes(&self, input: u32) -> Result<Vec<u8>, BlackBoxResolutionError> { + self.solver.bigint_to_bytes(input) + } + + pub fn bigint_op( + &mut self, + lhs: u32, + rhs: u32, + func: BlackBoxFunc, + ) -> Result<u32, BlackBoxResolutionError> { + let modulus_lhs = self.solver.get_modulus(lhs, func)?; + let modulus_rhs = self.solver.get_modulus(rhs, func)?; + if modulus_lhs != modulus_rhs { + return Err(BlackBoxResolutionError::Failed( + func, + "moduli should be identical in BigInt operation".to_string(), + )); + } + let id = self.create_bigint_id(); + self.solver.bigint_op(lhs, rhs, id, func)?; + Ok(id) + } +} diff --git a/acvm-repo/blackbox_solver/src/lib.rs b/acvm-repo/blackbox_solver/src/lib.rs index d8f926fcb4b..0fa56c2f531 100644 --- a/acvm-repo/blackbox_solver/src/lib.rs +++ b/acvm-repo/blackbox_solver/src/lib.rs @@ -18,7 +18,7 @@ mod hash; mod logic; pub use aes128::aes128_encrypt; -pub use bigint::BigIntSolver; +pub use bigint::{BigIntSolver, BigIntSolverWithId}; pub use curve_specific_solver::{BlackBoxFunctionSolver, StubbedBlackBoxSolver}; pub use ecdsa::{ecdsa_secp256k1_verify, ecdsa_secp256r1_verify}; pub use hash::{blake2s, blake3, keccakf1600, sha256_compression}; diff --git a/acvm-repo/brillig/src/black_box.rs b/acvm-repo/brillig/src/black_box.rs index 3264388c8ef..cbb268c0a50 100644 --- a/acvm-repo/brillig/src/black_box.rs +++ b/acvm-repo/brillig/src/black_box.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; /// These opcodes provide an equivalent of ACIR blackbox functions. /// They are implemented as native functions in the VM. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash)] pub enum BlackBoxOp { /// Encrypts a message using AES128. AES128Encrypt { diff --git a/acvm-repo/brillig/src/opcodes.rs b/acvm-repo/brillig/src/opcodes.rs index 8b72b5a9b41..1cb31ca3d0a 100644 --- a/acvm-repo/brillig/src/opcodes.rs +++ b/acvm-repo/brillig/src/opcodes.rs @@ -56,7 +56,7 @@ impl MemoryAddress { } /// Describes the memory layout for an array/vector element -#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize, Hash)] pub enum HeapValueType { // A single field element is enough to represent the value with a given bit size Simple(BitSize), @@ -81,7 +81,7 @@ impl HeapValueType { } /// A fixed-sized array starting from a Brillig memory location.
-#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Copy)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Copy, Hash)] pub struct HeapArray { pub pointer: MemoryAddress, pub size: usize, @@ -94,13 +94,13 @@ impl Default for HeapArray { } /// A memory-sized vector passed starting from a Brillig memory location and with a memory-held size -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Copy)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Copy, Hash)] pub struct HeapVector { pub pointer: MemoryAddress, pub size: MemoryAddress, } -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Copy, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Copy, PartialOrd, Ord, Hash)] pub enum IntegerBitSize { U1, U8, @@ -152,7 +152,7 @@ impl std::fmt::Display for IntegerBitSize { } } -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Copy, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Copy, PartialOrd, Ord, Hash)] pub enum BitSize { Field, Integer(IntegerBitSize), @@ -181,7 +181,7 @@ impl BitSize { /// While we are usually agnostic to how memory is passed within Brillig, /// this needs to be encoded somehow when dealing with an external system. /// For simplicity, the extra type information is given right in the ForeignCall instructions. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Copy)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Copy, Hash)] pub enum ValueOrArray { /// A single value passed to or from an external call /// It is an 'immediate' value - used without dereferencing. @@ -198,7 +198,7 @@ pub enum ValueOrArray { HeapVector(HeapVector), } -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub enum BrilligOpcode { /// Takes the fields in addresses `lhs` and `rhs` /// Performs the specified binary operation @@ -314,7 +314,7 @@ pub enum BrilligOpcode { } /// Binary fixed-length field expressions -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash)] pub enum BinaryFieldOp { Add, Sub, @@ -332,7 +332,7 @@ pub enum BinaryFieldOp { } /// Binary fixed-length integer expressions -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash)] pub enum BinaryIntOp { Add, Sub, diff --git a/acvm-repo/brillig_vm/src/black_box.rs b/acvm-repo/brillig_vm/src/black_box.rs index 0d90a4c8502..19e2dd7553d 100644 --- a/acvm-repo/brillig_vm/src/black_box.rs +++ b/acvm-repo/brillig_vm/src/black_box.rs @@ -1,9 +1,8 @@ use acir::brillig::{BlackBoxOp, HeapArray, HeapVector, IntegerBitSize}; use acir::{AcirField, BlackBoxFunc}; -use acvm_blackbox_solver::BigIntSolver; use acvm_blackbox_solver::{ aes128_encrypt, blake2s, blake3, ecdsa_secp256k1_verify, ecdsa_secp256r1_verify, keccakf1600, - sha256_compression, BlackBoxFunctionSolver, BlackBoxResolutionError, + sha256_compression, BigIntSolverWithId, BlackBoxFunctionSolver, BlackBoxResolutionError, }; use num_bigint::BigUint; use num_traits::Zero; @@ -39,11 +38,13 @@ fn to_value_vec(input: &[u8]) -> Vec> { input.iter().map(|&x| x.into()).collect() } +pub(crate) type BrilligBigIntSolver = BigIntSolverWithId; + pub(crate) fn evaluate_black_box>( op: &BlackBoxOp, solver: &Solver, memory: &mut Memory, - bigint_solver: &mut BrilligBigintSolver, + 
bigint_solver: &mut BrilligBigIntSolver, ) -> Result<(), BlackBoxResolutionError> { match op { BlackBoxOp::AES128Encrypt { inputs, iv, key, outputs } => { @@ -56,7 +57,7 @@ pub(crate) fn evaluate_black_box })?; let key: [u8; 16] = to_u8_vec(read_heap_array(memory, key)).try_into().map_err(|_| { - BlackBoxResolutionError::Failed(bb_func, "Invalid ley length".to_string()) + BlackBoxResolutionError::Failed(bb_func, "Invalid key length".to_string()) })?; let ciphertext = aes128_encrypt(&inputs, iv, key)?; @@ -353,54 +354,6 @@ pub(crate) fn evaluate_black_box } } -/// Wrapper over the generic bigint solver to automatically assign bigint ids in brillig -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub(crate) struct BrilligBigintSolver { - bigint_solver: BigIntSolver, - last_id: u32, -} - -impl BrilligBigintSolver { - pub(crate) fn create_bigint_id(&mut self) -> u32 { - let output = self.last_id; - self.last_id += 1; - output - } - - pub(crate) fn bigint_from_bytes( - &mut self, - inputs: &[u8], - modulus: &[u8], - ) -> Result { - let id = self.create_bigint_id(); - self.bigint_solver.bigint_from_bytes(inputs, modulus, id)?; - Ok(id) - } - - pub(crate) fn bigint_to_bytes(&self, input: u32) -> Result, BlackBoxResolutionError> { - self.bigint_solver.bigint_to_bytes(input) - } - - pub(crate) fn bigint_op( - &mut self, - lhs: u32, - rhs: u32, - func: BlackBoxFunc, - ) -> Result { - let modulus_lhs = self.bigint_solver.get_modulus(lhs, func)?; - let modulus_rhs = self.bigint_solver.get_modulus(rhs, func)?; - if modulus_lhs != modulus_rhs { - return Err(BlackBoxResolutionError::Failed( - func, - "moduli should be identical in BigInt operation".to_string(), - )); - } - let id = self.create_bigint_id(); - self.bigint_solver.bigint_op(lhs, rhs, id, func)?; - Ok(id) - } -} - fn black_box_function_from_op(op: &BlackBoxOp) -> BlackBoxFunc { match op { BlackBoxOp::AES128Encrypt { .. } => BlackBoxFunc::AES128Encrypt, diff --git a/acvm-repo/brillig_vm/src/lib.rs b/acvm-repo/brillig_vm/src/lib.rs index 45025fbb208..5b3688339b5 100644 --- a/acvm-repo/brillig_vm/src/lib.rs +++ b/acvm-repo/brillig_vm/src/lib.rs @@ -17,7 +17,7 @@ use acir::brillig::{ use acir::AcirField; use acvm_blackbox_solver::BlackBoxFunctionSolver; use arithmetic::{evaluate_binary_field_op, evaluate_binary_int_op, BrilligArithmeticError}; -use black_box::{evaluate_black_box, BrilligBigintSolver}; +use black_box::{evaluate_black_box, BrilligBigIntSolver}; // Re-export `brillig`. pub use acir::brillig; @@ -95,7 +95,7 @@ pub struct VM<'a, F, B: BlackBoxFunctionSolver> { /// The solver for blackbox functions black_box_solver: &'a B, // The solver for big integers - bigint_solver: BrilligBigintSolver, + bigint_solver: BrilligBigIntSolver, // Flag that determines whether we want to profile VM. profiling_active: bool, // Samples for profiling the VM execution. diff --git a/compiler/noirc_driver/src/debug.rs b/compiler/noirc_driver/src/debug.rs index f5eaede89b2..6044e6c0e65 100644 --- a/compiler/noirc_driver/src/debug.rs +++ b/compiler/noirc_driver/src/debug.rs @@ -8,7 +8,7 @@ use std::{ /// For a given file, we store the source code and the path to the file /// so consumers of the debug artifact can reconstruct the original source code structure. 
-#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, Hash)] pub struct DebugFile { pub source: String, pub path: PathBuf, } diff --git a/compiler/noirc_driver/src/lib.rs b/compiler/noirc_driver/src/lib.rs index 72ea464805f..5bedefaf563 100644 --- a/compiler/noirc_driver/src/lib.rs +++ b/compiler/noirc_driver/src/lib.rs @@ -13,7 +13,7 @@ use noirc_abi::{AbiParameter, AbiType, AbiValue}; use noirc_errors::{CustomDiagnostic, FileDiagnostic}; use noirc_evaluator::create_program; use noirc_evaluator::errors::RuntimeError; -use noirc_evaluator::ssa::SsaProgramArtifact; +use noirc_evaluator::ssa::{SsaLogging, SsaProgramArtifact}; use noirc_frontend::debug::build_debug_crate_file; use noirc_frontend::hir::def_map::{Contract, CrateDefMap}; use noirc_frontend::hir::Context; @@ -70,6 +70,11 @@ pub struct CompileOptions { #[arg(long, hide = true)] pub show_ssa: bool, + /// Only show SSA passes whose name contains the provided string. + /// This setting takes precedence over `show_ssa` if it's not empty. + #[arg(long, hide = true)] + pub show_ssa_pass_name: Option<String>, + /// Emit the unoptimized SSA IR to file. /// The IR will be dumped into the workspace target directory, /// under `[compiled-package].ssa.json`. @@ -126,11 +131,19 @@ pub struct CompileOptions { #[arg(long)] pub skip_underconstrained_check: bool, - /// Setting to decide on an inlining strategy for brillig functions. + /// Setting to decide on an inlining strategy for Brillig functions. /// A more aggressive inliner should generate larger programs but more optimized /// A less aggressive inliner should generate smaller programs #[arg(long, hide = true, allow_hyphen_values = true, default_value_t = i64::MAX)] pub inliner_aggressiveness: i64, + + /// Setting the maximum acceptable increase in Brillig bytecode size due to + /// unrolling small loops. When left empty, any change is accepted as long + /// as it requires fewer SSA instructions. + /// A higher value results in fewer jumps but a larger program. + /// A lower value keeps the original program if it was smaller, even if it has more jumps. + #[arg(long, hide = true, allow_hyphen_values = true)] + pub max_bytecode_increase_percent: Option<i32>, } pub fn parse_expression_width(input: &str) -> Result<ExpressionWidth, std::io::Error> { @@ -321,6 +334,8 @@ pub fn compute_function_abi( /// /// On success this returns the compiled program alongside any warnings that were found. /// On error this returns the non-empty list of warnings and errors. +/// +/// See [compile_no_check] for further information about the use of `cached_program`. pub fn compile_main( context: &mut Context, crate_id: CrateId, @@ -542,6 +557,15 @@ pub const DEFAULT_EXPRESSION_WIDTH: ExpressionWidth = ExpressionWidth::Bounded { /// Compile the current crate using `main_function` as the entrypoint. /// /// This function assumes [`check_crate`] is called beforehand. +/// +/// If the program is not returned from cache, it is backend-agnostic and must go through a transformation +/// pass before usage in proof generation; if it's returned from cache these transformations might have +/// already been applied. +/// +/// The transformations are _not_ covered by the check that decides whether we can use the cached artifact.
+/// That comparison is based on [CompiledProgram::hash] which is a persisted version of the hash of the input +/// [`ast::Program`][noirc_frontend::monomorphization::ast::Program], whereas the output [`circuit::Program`][acir::circuit::Program] +/// contains the final optimized ACIR opcodes, including the transformations done after this compilation. #[tracing::instrument(level = "trace", skip_all, fields(function_name = context.function_name(&main_function)))] pub fn compile_no_check( context: &mut Context, @@ -556,8 +580,6 @@ pub fn compile_no_check( monomorphize(main_function, &mut context.def_interner)? }; - let hash = fxhash::hash64(&program); - let hashes_match = cached_program.as_ref().map_or(false, |program| program.hash == hash); if options.show_monomorphized { println!("{program}"); } @@ -571,13 +593,28 @@ pub fn compile_no_check( || options.show_ssa || options.emit_ssa; - if !force_compile && hashes_match { - info!("Program matches existing artifact, returning early"); - return Ok(cached_program.expect("cache must exist for hashes to match")); + // Hash the AST program, which is going to be used to fingerprint the compilation artifact. + let hash = fxhash::hash64(&program); + + if let Some(cached_program) = cached_program { + if !force_compile && cached_program.hash == hash { + info!("Program matches existing artifact, returning early"); + return Ok(cached_program); + } } + let return_visibility = program.return_visibility; let ssa_evaluator_options = noirc_evaluator::ssa::SsaEvaluatorOptions { - enable_ssa_logging: options.show_ssa, + ssa_logging: match &options.show_ssa_pass_name { + Some(string) => SsaLogging::Contains(string.clone()), + None => { + if options.show_ssa { + SsaLogging::All + } else { + SsaLogging::None + } + } + }, enable_brillig_logging: options.show_brillig, force_brillig_output: options.force_brillig, print_codegen_timings: options.benchmark_codegen, @@ -589,6 +626,7 @@ pub fn compile_no_check( emit_ssa: if options.emit_ssa { Some(context.package_build_path.clone()) } else { None }, skip_underconstrained_check: options.skip_underconstrained_check, inliner_aggressiveness: options.inliner_aggressiveness, + max_bytecode_increase_percent: options.max_bytecode_increase_percent, }; let SsaProgramArtifact { program, debug, warnings, names, brillig_names, error_types, ..
} = diff --git a/compiler/noirc_driver/src/program.rs b/compiler/noirc_driver/src/program.rs index 88460482928..4b4d6662e8e 100644 --- a/compiler/noirc_driver/src/program.rs +++ b/compiler/noirc_driver/src/program.rs @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize}; use super::debug::DebugFile; -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone, Hash)] pub struct CompiledProgram { pub noir_version: String, /// Hash of the [`Program`][noirc_frontend::monomorphization::ast::Program] from which this [`CompiledProgram`] diff --git a/compiler/noirc_errors/src/debug_info.rs b/compiler/noirc_errors/src/debug_info.rs index 77028f739bd..a5e12b37712 100644 --- a/compiler/noirc_errors/src/debug_info.rs +++ b/compiler/noirc_errors/src/debug_info.rs @@ -94,7 +94,7 @@ impl ProgramDebugInfo { } #[serde_as] -#[derive(Default, Debug, Clone, Deserialize, Serialize)] +#[derive(Default, Debug, Clone, Deserialize, Serialize, Hash)] pub struct DebugInfo { /// Map opcode index of an ACIR circuit into the source code location /// Serde does not support mapping keys being enums for json, so we indicate diff --git a/compiler/noirc_errors/src/position.rs b/compiler/noirc_errors/src/position.rs index 8131db323b9..c7a64c4f422 100644 --- a/compiler/noirc_errors/src/position.rs +++ b/compiler/noirc_errors/src/position.rs @@ -94,8 +94,10 @@ impl Span { self.start() <= other.start() && self.end() >= other.end() } + /// Returns `true` if any point of `self` intersects a point of `other`. + /// Adjacent spans are considered to intersect (so, for example, `0..1` intersects `1..3`). pub fn intersects(&self, other: &Span) -> bool { - self.end() > other.start() && self.start() < other.end() + self.end() >= other.start() && self.start() <= other.end() } pub fn is_smaller(&self, other: &Span) -> bool { @@ -140,3 +142,37 @@ impl Location { self.file == other.file && self.span.contains(&other.span) } } + +#[cfg(test)] +mod tests { + use crate::Span; + + #[test] + fn test_intersects() { + assert!(Span::from(5..10).intersects(&Span::from(5..10))); + + assert!(Span::from(5..10).intersects(&Span::from(5..5))); + assert!(Span::from(5..5).intersects(&Span::from(5..10))); + + assert!(Span::from(10..10).intersects(&Span::from(5..10))); + assert!(Span::from(5..10).intersects(&Span::from(10..10))); + + assert!(Span::from(5..10).intersects(&Span::from(6..9))); + assert!(Span::from(6..9).intersects(&Span::from(5..10))); + + assert!(Span::from(5..10).intersects(&Span::from(4..11))); + assert!(Span::from(4..11).intersects(&Span::from(5..10))); + + assert!(Span::from(5..10).intersects(&Span::from(4..6))); + assert!(Span::from(4..6).intersects(&Span::from(5..10))); + + assert!(Span::from(5..10).intersects(&Span::from(9..11))); + assert!(Span::from(9..11).intersects(&Span::from(5..10))); + + assert!(!Span::from(5..10).intersects(&Span::from(3..4))); + assert!(!Span::from(3..4).intersects(&Span::from(5..10))); + + assert!(!Span::from(5..10).intersects(&Span::from(11..12))); + assert!(!Span::from(11..12).intersects(&Span::from(5..10))); + } +} diff --git a/compiler/noirc_evaluator/Cargo.toml b/compiler/noirc_evaluator/Cargo.toml index e25b5bf855a..bb8c62cfd95 100644 --- a/compiler/noirc_evaluator/Cargo.toml +++ b/compiler/noirc_evaluator/Cargo.toml @@ -33,6 +33,7 @@ cfg-if.workspace = true proptest.workspace = true similar-asserts.workspace = true num-traits.workspace = true +test-case.workspace = true [features] bn254 = ["noirc_frontend/bn254"] diff --git 
a/compiler/noirc_evaluator/src/acir/acir_variable.rs b/compiler/noirc_evaluator/src/acir/acir_variable.rs index 6ba072f01a4..9f2c649ee3e 100644 --- a/compiler/noirc_evaluator/src/acir/acir_variable.rs +++ b/compiler/noirc_evaluator/src/acir/acir_variable.rs @@ -92,7 +92,7 @@ impl<'a> From<&'a SsaType> for AcirType { SsaType::Numeric(numeric_type) => AcirType::NumericType(*numeric_type), SsaType::Array(elements, size) => { let elements = elements.iter().map(|e| e.into()).collect(); - AcirType::Array(elements, *size) + AcirType::Array(elements, *size as usize) } _ => unreachable!("The type {value} cannot be represented in ACIR"), } @@ -578,7 +578,7 @@ impl> AcirContext { let numeric_type = match typ { AcirType::NumericType(numeric_type) => numeric_type, AcirType::Array(_, _) => { - todo!("cannot divide arrays. This should have been caught by the frontend") + unreachable!("cannot divide arrays. This should have been caught by the frontend") } }; match numeric_type { @@ -1084,11 +1084,22 @@ impl> AcirContext { &mut self, lhs: AcirVar, rhs: AcirVar, + typ: AcirType, bit_size: u32, predicate: AcirVar, ) -> Result { - let (_, remainder) = self.euclidean_division_var(lhs, rhs, bit_size, predicate)?; - Ok(remainder) + let numeric_type = match typ { + AcirType::NumericType(numeric_type) => numeric_type, + AcirType::Array(_, _) => { + unreachable!("cannot modulo arrays. This should have been caught by the frontend") + } + }; + + let (_, remainder_var) = match numeric_type { + NumericType::Signed { bit_size } => self.signed_division_var(lhs, rhs, bit_size)?, + _ => self.euclidean_division_var(lhs, rhs, bit_size, predicate)?, + }; + Ok(remainder_var) } /// Constrains the `AcirVar` variable to be of type `NumericType`. diff --git a/compiler/noirc_evaluator/src/acir/mod.rs b/compiler/noirc_evaluator/src/acir/mod.rs index 7274fe908d1..76f0dea95bb 100644 --- a/compiler/noirc_evaluator/src/acir/mod.rs +++ b/compiler/noirc_evaluator/src/acir/mod.rs @@ -571,7 +571,7 @@ impl<'a> Context<'a> { AcirValue::Array(_) => { let block_id = self.block_id(param_id); let len = if matches!(typ, Type::Array(_, _)) { - typ.flattened_size() + typ.flattened_size() as usize } else { return Err(InternalError::Unexpected { expected: "Block params should be an array".to_owned(), @@ -816,7 +816,9 @@ impl<'a> Context<'a> { let inputs = vecmap(arguments, |arg| self.convert_value(*arg, dfg)); let output_count = result_ids .iter() - .map(|result_id| dfg.type_of_value(*result_id).flattened_size()) + .map(|result_id| { + dfg.type_of_value(*result_id).flattened_size() as usize + }) .sum(); let Some(acir_function_id) = @@ -948,7 +950,7 @@ impl<'a> Context<'a> { let block_id = self.block_id(&array_id); let array_typ = dfg.type_of_value(array_id); let len = if matches!(array_typ, Type::Array(_, _)) { - array_typ.flattened_size() + array_typ.flattened_size() as usize } else { Self::flattened_value_size(&output) }; @@ -1444,7 +1446,7 @@ impl<'a> Context<'a> { // a separate SSA value and restrictions on slice indices should be generated elsewhere in the SSA. 
let array_typ = dfg.type_of_value(array); let array_len = if !array_typ.contains_slice_element() { - array_typ.flattened_size() + array_typ.flattened_size() as usize } else { self.flattened_slice_size(array, dfg) }; @@ -1539,7 +1541,7 @@ impl<'a> Context<'a> { let value = self.convert_value(array, dfg); let array_typ = dfg.type_of_value(array); let len = if !array_typ.contains_slice_element() { - array_typ.flattened_size() + array_typ.flattened_size() as usize } else { self.flattened_slice_size(array, dfg) }; @@ -1810,7 +1812,7 @@ impl<'a> Context<'a> { return_values .iter() - .fold(0, |acc, value_id| acc + dfg.type_of_value(*value_id).flattened_size()) + .fold(0, |acc, value_id| acc + dfg.type_of_value(*value_id).flattened_size() as usize) } /// Converts an SSA terminator's return values into their ACIR representations @@ -1968,6 +1970,7 @@ impl<'a> Context<'a> { BinaryOp::Mod => self.acir_context.modulo_var( lhs, rhs, + binary_type.clone(), bit_count, self.current_side_effects_enabled_var, ), @@ -2155,7 +2158,7 @@ impl<'a> Context<'a> { let inputs = vecmap(&arguments_no_slice_len, |arg| self.convert_value(*arg, dfg)); let output_count = result_ids.iter().fold(0usize, |sum, result_id| { - sum + dfg.try_get_array_length(*result_id).unwrap_or(1) + sum + dfg.try_get_array_length(*result_id).unwrap_or(1) as usize }); let vars = self.acir_context.black_box_function(black_box, inputs, output_count)?; @@ -2179,7 +2182,7 @@ impl<'a> Context<'a> { endian, field, radix, - array_length as u32, + array_length, result_type[0].clone().into(), ) .map(|array| vec![array]) @@ -2193,12 +2196,7 @@ impl<'a> Context<'a> { }; self.acir_context - .bit_decompose( - endian, - field, - array_length as u32, - result_type[0].clone().into(), - ) + .bit_decompose(endian, field, array_length, result_type[0].clone().into()) .map(|array| vec![array]) } Intrinsic::ArrayLen => { @@ -2219,7 +2217,7 @@ impl<'a> Context<'a> { let acir_value = self.convert_value(slice_contents, dfg); let array_len = if !slice_typ.contains_slice_element() { - slice_typ.flattened_size() + slice_typ.flattened_size() as usize } else { self.flattened_slice_size(slice_contents, dfg) }; diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs index 1fa4985295a..9c88c559b59 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs @@ -1823,7 +1823,7 @@ impl<'block> BrilligBlock<'block> { Type::Array(_, nested_size) => { let inner_array = BrilligArray { pointer: self.brillig_context.allocate_register(), - size: *nested_size, + size: *nested_size as usize, }; self.allocate_foreign_call_result_array(element_type, inner_array); diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block_variables.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block_variables.rs index 393d4c967c2..bf0a1bc7347 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block_variables.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block_variables.rs @@ -142,7 +142,7 @@ pub(crate) fn allocate_value( } Type::Array(item_typ, elem_count) => BrilligVariable::BrilligArray(BrilligArray { pointer: brillig_context.allocate_register(), - size: compute_array_length(&item_typ, elem_count), + size: compute_array_length(&item_typ, elem_count as usize), }), Type::Slice(_) => BrilligVariable::BrilligVector(BrilligVector { pointer: 
brillig_context.allocate_register(), diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs index 2779be103cd..3dea7b3e7f5 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs @@ -59,7 +59,7 @@ impl FunctionContext { vecmap(item_type.iter(), |item_typ| { FunctionContext::ssa_type_to_parameter(item_typ) }), - *size, + *size as usize, ), Type::Slice(_) => { panic!("ICE: Slice parameters cannot be derived from type information") diff --git a/compiler/noirc_evaluator/src/brillig/brillig_ir/brillig_variable.rs b/compiler/noirc_evaluator/src/brillig/brillig_ir/brillig_variable.rs index 81d61e05cc4..0bb18448670 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_ir/brillig_variable.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_ir/brillig_variable.rs @@ -88,7 +88,7 @@ pub(crate) fn type_to_heap_value_type(typ: &Type) -> HeapValueType { ), Type::Array(elem_type, size) => HeapValueType::Array { value_types: elem_type.as_ref().iter().map(type_to_heap_value_type).collect(), - size: typ.element_size() * size, + size: typ.element_size() * *size as usize, }, Type::Slice(elem_type) => HeapValueType::Vector { value_types: elem_type.as_ref().iter().map(type_to_heap_value_type).collect(), diff --git a/compiler/noirc_evaluator/src/brillig/mod.rs b/compiler/noirc_evaluator/src/brillig/mod.rs index 1b61ae1a864..cb8c35cd8e0 100644 --- a/compiler/noirc_evaluator/src/brillig/mod.rs +++ b/compiler/noirc_evaluator/src/brillig/mod.rs @@ -12,7 +12,7 @@ use self::{ }, }; use crate::ssa::{ - ir::function::{Function, FunctionId, RuntimeType}, + ir::function::{Function, FunctionId}, ssa_gen::Ssa, }; use fxhash::FxHashMap as HashMap; @@ -59,7 +59,7 @@ impl std::ops::Index for Brillig { } impl Ssa { - /// Compile to brillig brillig functions and ACIR functions reachable from them + /// Compile Brillig functions and ACIR functions reachable from them #[tracing::instrument(level = "trace", skip_all)] pub(crate) fn to_brillig(&self, enable_debug_trace: bool) -> Brillig { // Collect all the function ids that are reachable from brillig @@ -67,9 +67,7 @@ impl Ssa { let brillig_reachable_function_ids = self .functions .iter() - .filter_map(|(id, func)| { - matches!(func.runtime(), RuntimeType::Brillig(_)).then_some(*id) - }) + .filter_map(|(id, func)| func.runtime().is_brillig().then_some(*id)) .collect::>(); let mut brillig = Brillig::default(); diff --git a/compiler/noirc_evaluator/src/errors.rs b/compiler/noirc_evaluator/src/errors.rs index 994e97eabb8..75a3ceb3a72 100644 --- a/compiler/noirc_evaluator/src/errors.rs +++ b/compiler/noirc_evaluator/src/errors.rs @@ -63,7 +63,7 @@ pub enum RuntimeError { UnknownReference { call_stack: CallStack }, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, Hash)] pub enum SsaReport { Warning(InternalWarning), Bug(InternalBug), @@ -107,7 +107,7 @@ impl From for FileDiagnostic { } } -#[derive(Debug, PartialEq, Eq, Clone, Error, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Eq, Clone, Error, Serialize, Deserialize, Hash)] pub enum InternalWarning { #[error("Return variable contains a constant value")] ReturnConstant { call_stack: CallStack }, @@ -115,7 +115,7 @@ pub enum InternalWarning { VerifyProof { call_stack: CallStack }, } -#[derive(Debug, PartialEq, Eq, Clone, Error, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Eq, Clone, Error, 
Serialize, Deserialize, Hash)] pub enum InternalBug { #[error("Input to brillig function is in a separate subgraph to output")] IndependentSubgraph { call_stack: CallStack }, diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 97c1760d87c..8f31023f790 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -44,9 +44,16 @@ mod opt; pub(crate) mod parser; pub mod ssa_gen; +#[derive(Debug, Clone)] +pub enum SsaLogging { + None, + All, + Contains(String), +} + pub struct SsaEvaluatorOptions { /// Emit debug information for the intermediate SSA IR - pub enable_ssa_logging: bool, + pub ssa_logging: SsaLogging, pub enable_brillig_logging: bool, @@ -67,6 +74,11 @@ pub struct SsaEvaluatorOptions { /// The higher the value, the more inlined brillig functions will be. pub inliner_aggressiveness: i64, + + /// Maximum accepted percentage increase in the Brillig bytecode size after unrolling loops. + /// When `None` the size increase check is skipped altogether and any decrease in the SSA + /// instruction count is accepted. + pub max_bytecode_increase_percent: Option<i32>, } pub(crate) struct ArtifactsAndWarnings(Artifacts, Vec<SsaReport>); @@ -85,46 +97,49 @@ pub(crate) fn optimize_into_acir( let mut ssa = SsaBuilder::new( program, - options.enable_ssa_logging, + options.ssa_logging.clone(), options.force_brillig_output, options.print_codegen_timings, &options.emit_ssa, )? - .run_pass(Ssa::defunctionalize, "After Defunctionalization:") - .run_pass(Ssa::remove_paired_rc, "After Removing Paired rc_inc & rc_decs:") - .run_pass(Ssa::separate_runtime, "After Runtime Separation:") - .run_pass(Ssa::resolve_is_unconstrained, "After Resolving IsUnconstrained:") - .run_pass(|ssa| ssa.inline_functions(options.inliner_aggressiveness), "After Inlining (1st):") + .run_pass(Ssa::defunctionalize, "Defunctionalization") + .run_pass(Ssa::remove_paired_rc, "Removing Paired rc_inc & rc_decs") + .run_pass(Ssa::separate_runtime, "Runtime Separation") + .run_pass(Ssa::resolve_is_unconstrained, "Resolving IsUnconstrained") + .run_pass(|ssa| ssa.inline_functions(options.inliner_aggressiveness), "Inlining (1st)") // Run mem2reg with the CFG separated into blocks - .run_pass(Ssa::mem2reg, "After Mem2Reg (1st):") - .run_pass(Ssa::simplify_cfg, "After Simplifying (1st):") - .run_pass(Ssa::as_slice_optimization, "After `as_slice` optimization") + .run_pass(Ssa::mem2reg, "Mem2Reg (1st)") + .run_pass(Ssa::simplify_cfg, "Simplifying (1st)") + .run_pass(Ssa::as_slice_optimization, "`as_slice` optimization") .try_run_pass( Ssa::evaluate_static_assert_and_assert_constant, - "After `static_assert` and `assert_constant`:", + "`static_assert` and `assert_constant`", + )? + .run_pass(Ssa::loop_invariant_code_motion, "Loop Invariant Code Motion") + .try_run_pass( + |ssa| ssa.unroll_loops_iteratively(options.max_bytecode_increase_percent), + "Unrolling", )? - .run_pass(Ssa::loop_invariant_code_motion, "After Loop Invariant Code Motion:") - .try_run_pass(Ssa::unroll_loops_iteratively, "After Unrolling:")?
- .run_pass(Ssa::simplify_cfg, "After Simplifying (2nd):") - .run_pass(Ssa::flatten_cfg, "After Flattening:") - .run_pass(Ssa::remove_bit_shifts, "After Removing Bit Shifts:") + .run_pass(Ssa::simplify_cfg, "Simplifying (2nd)") + .run_pass(Ssa::flatten_cfg, "Flattening") + .run_pass(Ssa::remove_bit_shifts, "Removing Bit Shifts") // Run mem2reg once more with the flattened CFG to catch any remaining loads/stores - .run_pass(Ssa::mem2reg, "After Mem2Reg (2nd):") + .run_pass(Ssa::mem2reg, "Mem2Reg (2nd)") // Run the inlining pass again to handle functions with `InlineType::NoPredicates`. // Before flattening is run, we treat functions marked with the `InlineType::NoPredicates` as an entry point. // This pass must come immediately following `mem2reg` as the succeeding passes // may create an SSA which inlining fails to handle. .run_pass( |ssa| ssa.inline_functions_with_no_predicates(options.inliner_aggressiveness), - "After Inlining (2nd):", + "Inlining (2nd)", ) - .run_pass(Ssa::remove_if_else, "After Remove IfElse:") - .run_pass(Ssa::fold_constants, "After Constant Folding:") - .run_pass(Ssa::remove_enable_side_effects, "After EnableSideEffectsIf removal:") - .run_pass(Ssa::fold_constants_using_constraints, "After Constraint Folding:") - .run_pass(Ssa::dead_instruction_elimination, "After Dead Instruction Elimination:") - .run_pass(Ssa::simplify_cfg, "After Simplifying:") - .run_pass(Ssa::array_set_optimization, "After Array Set Optimizations:") + .run_pass(Ssa::remove_if_else, "Remove IfElse") + .run_pass(Ssa::fold_constants, "Constant Folding") + .run_pass(Ssa::remove_enable_side_effects, "EnableSideEffectsIf removal") + .run_pass(Ssa::fold_constants_using_constraints, "Constraint Folding") + .run_pass(Ssa::dead_instruction_elimination, "Dead Instruction Elimination (1st)") + .run_pass(Ssa::simplify_cfg, "Simplifying (3rd)") + .run_pass(Ssa::array_set_optimization, "Array Set Optimizations") .finish(); let ssa_level_warnings = if options.skip_underconstrained_check { @@ -146,14 +161,11 @@ pub(crate) fn optimize_into_acir( let ssa = SsaBuilder { ssa, - print_ssa_passes: options.enable_ssa_logging, + ssa_logging: options.ssa_logging.clone(), print_codegen_timings: options.print_codegen_timings, } - .run_pass( - |ssa| ssa.fold_constants_with_brillig(&brillig), - "After Constant Folding with Brillig:", - ) - .run_pass(Ssa::dead_instruction_elimination, "After Dead Instruction Elimination:") + .run_pass(|ssa| ssa.fold_constants_with_brillig(&brillig), "Inlining Brillig Calls") + .run_pass(Ssa::dead_instruction_elimination, "Dead Instruction Elimination (2nd)") .finish(); drop(ssa_gen_span_guard); @@ -226,7 +238,7 @@ impl SsaProgramArtifact { } } -/// Compiles the [`Program`] into [`ACIR``][acvm::acir::circuit::Program]. +/// Compiles the [`Program`] into [`ACIR`][acvm::acir::circuit::Program]. /// /// The output ACIR is backend-agnostic and so must go through a transformation pass before usage in proof generation. #[tracing::instrument(level = "trace", skip_all)] @@ -411,14 +423,14 @@ fn split_public_and_private_inputs( // This is just a convenience object to bundle the ssa with `print_ssa_passes` for debug printing.
struct SsaBuilder { ssa: Ssa, - print_ssa_passes: bool, + ssa_logging: SsaLogging, print_codegen_timings: bool, } impl SsaBuilder { fn new( program: Program, - print_ssa_passes: bool, + ssa_logging: SsaLogging, force_brillig_runtime: bool, print_codegen_timings: bool, emit_ssa: &Option<PathBuf>, ) -> Result<SsaBuilder, RuntimeError> { @@ -433,7 +445,7 @@ impl SsaBuilder { let ssa_path = emit_ssa.with_extension("ssa.json"); write_to_file(&serde_json::to_vec(&ssa).unwrap(), &ssa_path); } - Ok(SsaBuilder { print_ssa_passes, print_codegen_timings, ssa }.print("Initial SSA:")) + Ok(SsaBuilder { ssa_logging, print_codegen_timings, ssa }.print("Initial SSA")) } fn finish(self) -> Ssa { @@ -450,19 +462,28 @@ impl SsaBuilder { } /// The same as `run_pass` but for passes that may fail - fn try_run_pass( - mut self, - pass: fn(Ssa) -> Result<Ssa, RuntimeError>, - msg: &str, - ) -> Result<Self, RuntimeError> { + fn try_run_pass<F>(mut self, pass: F, msg: &str) -> Result<Self, RuntimeError> + where + F: FnOnce(Ssa) -> Result<Ssa, RuntimeError>, + { self.ssa = time(msg, self.print_codegen_timings, || pass(self.ssa))?; Ok(self.print(msg)) } fn print(mut self, msg: &str) -> Self { - if self.print_ssa_passes { + let print_ssa_pass = match &self.ssa_logging { + SsaLogging::None => false, + SsaLogging::All => true, + SsaLogging::Contains(string) => { + let string = string.to_lowercase(); + let string = string.strip_prefix("after ").unwrap_or(&string); + let string = string.strip_suffix(':').unwrap_or(string); + msg.to_lowercase().contains(string) + } + }; + if print_ssa_pass { self.ssa.normalize_ids(); - println!("{msg}\n{}", self.ssa); + println!("After {msg}:\n{}", self.ssa); } self } diff --git a/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs b/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs index e4a2eeb8c22..bd2585a3bfa 100644 --- a/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs +++ b/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs @@ -160,7 +160,7 @@ impl FunctionBuilder { for value in values { self.add_to_data_bus(*value, &mut databus); } - let len = databus.values.len(); + let len = databus.values.len() as u32; let array = (len > 0 && matches!(self.current_function.runtime(), RuntimeType::Acir(_))) .then(|| { @@ -223,9 +223,11 @@ impl FunctionBuilder { ssa_params: &[ValueId], mut flattened_params_databus_visibility: Vec<DatabusVisibility>, ) -> Vec<DatabusVisibility> { - let ssa_param_sizes: Vec<_> = ssa_params + let ssa_param_sizes: Vec<usize> = ssa_params .iter() - .map(|ssa_param| self.current_function.dfg[*ssa_param].get_type().flattened_size()) + .map(|ssa_param| { + self.current_function.dfg[*ssa_param].get_type().flattened_size() as usize + }) .collect(); let mut is_ssa_params_databus = Vec::with_capacity(ssa_params.len()); diff --git a/compiler/noirc_evaluator/src/ssa/ir/dfg.rs b/compiler/noirc_evaluator/src/ssa/ir/dfg.rs index e3f3f33682b..827944e22d1 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/dfg.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/dfg.rs @@ -307,13 +307,13 @@ impl DataFlowGraph { instruction_id: InstructionId, ctrl_typevars: Option<Vec<Type>>, ) { - self.results.insert(instruction_id, Default::default()); + let result_types = self.instruction_result_types(instruction_id, ctrl_typevars); + let results = vecmap(result_types.into_iter().enumerate(), |(position, typ)| { + let instruction = instruction_id; + self.values.insert(Value::Instruction { typ, position, instruction }) + }); - // Get all of the types that this instruction produces // and append them as results.
-        for typ in self.instruction_result_types(instruction_id, ctrl_typevars) {
-            self.append_result(instruction_id, typ);
-        }
+        self.results.insert(instruction_id, results);
     }

     /// Return the result types of this instruction.
@@ -370,22 +370,6 @@ impl DataFlowGraph {
         matches!(self.values[value].get_type(), Type::Reference(_))
     }

-    /// Appends a result type to the instruction.
-    pub(crate) fn append_result(&mut self, instruction_id: InstructionId, typ: Type) -> ValueId {
-        let results = self.results.get_mut(&instruction_id).unwrap();
-        let expected_res_position = results.len();
-
-        let value_id = self.values.insert(Value::Instruction {
-            typ,
-            position: expected_res_position,
-            instruction: instruction_id,
-        });
-
-        // Add value to the list of results for this instruction
-        results.push(value_id);
-        value_id
-    }
-
     /// Replaces an instruction result with a fresh id.
     pub(crate) fn replace_result(
         &mut self,
@@ -463,7 +447,7 @@ impl DataFlowGraph {

     /// If this value is an array, return the length of the array as indicated by its type.
     /// Otherwise, return None.
-    pub(crate) fn try_get_array_length(&self, value: ValueId) -> Option<usize> {
+    pub(crate) fn try_get_array_length(&self, value: ValueId) -> Option<u32> {
         match self.type_of_value(value) {
             Type::Array(_, length) => Some(length),
             _ => None,
diff --git a/compiler/noirc_evaluator/src/ssa/ir/dom.rs b/compiler/noirc_evaluator/src/ssa/ir/dom.rs
index c1a7f14e0d1..504eecf4801 100644
--- a/compiler/noirc_evaluator/src/ssa/ir/dom.rs
+++ b/compiler/noirc_evaluator/src/ssa/ir/dom.rs
@@ -119,6 +119,29 @@ impl DominatorTree {
         }
     }

+    /// Walk up the dominator tree until we find a block for which `f` returns `Some` value.
+    /// Otherwise return `None` when we reach the top.
+    ///
+    /// Similar to `Iterator::filter_map` but only returns the first hit.
+    pub(crate) fn find_map_dominator<T>(
+        &self,
+        mut block_id: BasicBlockId,
+        f: impl Fn(BasicBlockId) -> Option<T>,
+    ) -> Option<T> {
+        if !self.is_reachable(block_id) {
+            return None;
+        }
+        loop {
+            if let Some(value) = f(block_id) {
+                return Some(value);
+            }
+            block_id = match self.immediate_dominator(block_id) {
+                Some(immediate_dominator) => immediate_dominator,
+                None => return None,
+            }
+        }
+    }
+
     /// Allocate and compute a dominator tree from a pre-computed control flow graph and
     /// post-order counterpart.
     pub(crate) fn with_cfg_and_post_order(cfg: &ControlFlowGraph, post_order: &PostOrder) -> Self {
@@ -448,4 +471,22 @@ mod tests {
         assert!(dt.dominates(block2_id, block1_id));
         assert!(dt.dominates(block2_id, block2_id));
     }
+
+    #[test]
+    fn test_find_map_dominator() {
+        let (dt, b0, b1, b2, _b3) = unreachable_node_setup();
+
+        assert_eq!(
+            dt.find_map_dominator(b2, |b| if b == b0 { Some("root") } else { None }),
+            Some("root")
+        );
+        assert_eq!(
+            dt.find_map_dominator(b1, |b| if b == b0 { Some("unreachable") } else { None }),
+            None
+        );
+        assert_eq!(
+            dt.find_map_dominator(b1, |b| if b == b1 { Some("not part of tree") } else { None }),
+            None
+        );
+    }
 }
diff --git a/compiler/noirc_evaluator/src/ssa/ir/function.rs b/compiler/noirc_evaluator/src/ssa/ir/function.rs
index b1233e3063e..6413107c04a 100644
--- a/compiler/noirc_evaluator/src/ssa/ir/function.rs
+++ b/compiler/noirc_evaluator/src/ssa/ir/function.rs
@@ -197,6 +197,12 @@ impl Function {
     }
 }

+impl Clone for Function {
+    fn clone(&self) -> Self {
+        Function::clone_with_id(self.id(), self)
+    }
+}
+
 impl std::fmt::Display for RuntimeType {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs
index 9958aadd443..0d7b9d10f4c 100644
--- a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs
+++ b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs
@@ -45,8 +45,7 @@ pub(crate) type InstructionId = Id<Instruction>;
 /// - Opcodes which the IR knows the target machine has
 ///   special support for. (LowLevel)
 /// - Opcodes which have no function definition in the
-///   source code and must be processed by the IR. An example
-///   of this is println.
+///   source code and must be processed by the IR.
 #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
 pub(crate) enum Intrinsic {
     ArrayLen,
@@ -112,6 +111,9 @@ impl Intrinsic {
     /// Returns whether the `Intrinsic` has side effects.
     ///
     /// If there are no side effects then the `Intrinsic` can be removed if the result is unused.
+    ///
+    /// An example of a side effect is increasing the reference count of an array, but functions
+    /// which can fail due to implicit constraints are also considered to have a side effect.
     pub(crate) fn has_side_effects(&self) -> bool {
         match self {
             Intrinsic::AssertConstant
@@ -152,6 +154,39 @@ impl Intrinsic {
         }
     }

+    /// Intrinsics which only have a side effect due to the chance that
+    /// they can fail a constraint can be deduplicated.
+    pub(crate) fn can_be_deduplicated(&self, deduplicate_with_predicate: bool) -> bool {
+        match self {
+            // These apply a constraint in the form of ACIR opcodes, but they can be deduplicated
+            // if the inputs are the same. If they depend on a side effect variable (e.g. because
+            // they were in an if-then-else) then `handle_instruction_side_effects` in `flatten_cfg`
+            // will have attached the condition variable to their inputs directly, so they don't
+            // directly depend on the corresponding `enable_side_effect` instruction any more.
+            // However, to conform with the expectations of `Instruction::can_be_deduplicated` and
+            // `constant_folding` we only use this information if the caller shows interest in it.
+            Intrinsic::ToBits(_)
+            | Intrinsic::ToRadix(_)
+            | Intrinsic::BlackBox(
+                BlackBoxFunc::MultiScalarMul
+                | BlackBoxFunc::EmbeddedCurveAdd
+                | BlackBoxFunc::RecursiveAggregation,
+            ) => deduplicate_with_predicate,
+
+            // Operations that remove items from a slice don't modify the slice; they just assert it's non-empty.
+            Intrinsic::SlicePopBack | Intrinsic::SlicePopFront | Intrinsic::SliceRemove => {
+                deduplicate_with_predicate
+            }
+
+            Intrinsic::AssertConstant
+            | Intrinsic::StaticAssert
+            | Intrinsic::ApplyRangeConstraint
+            | Intrinsic::AsWitness => deduplicate_with_predicate,
+
+            _ => !self.has_side_effects(),
+        }
+    }
+
     /// Lookup an Intrinsic by name and return it if found.
     /// If there is no such intrinsic by that name, None is returned.
     pub(crate) fn lookup(name: &str) -> Option<Intrinsic> {
@@ -245,7 +280,7 @@ pub(crate) enum Instruction {
     /// - `code1` will have side effects iff `condition1` evaluates to `true`
     ///
     /// This instruction is only emitted after the cfg flattening pass, and is used to annotate
-    /// instruction regions with an condition that corresponds to their position in the CFG's
+    /// instruction regions with a condition that corresponds to their position in the CFG's
     /// if-branching structure.
     EnableSideEffectsIf { condition: ValueId },
@@ -327,10 +362,66 @@ impl Instruction {
         matches!(self.result_type(), InstructionResultType::Unknown)
     }

+    /// Indicates if the instruction has a side effect, i.e. it can fail, or it interacts with memory.
+    ///
+    /// This is similar to `can_be_deduplicated`, but it doesn't depend on whether the caller takes
+    /// constraints into account, because it might not use it to isolate the side effects across branches.
+    pub(crate) fn has_side_effects(&self, dfg: &DataFlowGraph) -> bool {
+        use Instruction::*;
+
+        match self {
+            // These either have side-effects or interact with memory
+            EnableSideEffectsIf { .. }
+            | Allocate
+            | Load { .. }
+            | Store { .. }
+            | IncrementRc { .. }
+            | DecrementRc { .. } => true,
+
+            Call { func, .. } => match dfg[*func] {
+                Value::Intrinsic(intrinsic) => intrinsic.has_side_effects(),
+                _ => true, // Be conservative and assume other functions can have side effects.
+            },
+
+            // These can fail.
+            Constrain(..) | RangeCheck { .. } => true,
+
+            // This should never be side-effectful
+            MakeArray { .. } => false,
+
+            // Some binary math can overflow or underflow
+            Binary(binary) => match binary.operator {
+                BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Mod => {
+                    true
+                }
+                BinaryOp::Eq
+                | BinaryOp::Lt
+                | BinaryOp::And
+                | BinaryOp::Or
+                | BinaryOp::Xor
+                | BinaryOp::Shl
+                | BinaryOp::Shr => false,
+            },
+
+            // These can have different behavior depending on the EnableSideEffectsIf context.
+            Cast(_, _)
+            | Not(_)
+            | Truncate { .. }
+            | IfElse { .. }
+            | ArrayGet { .. }
+            | ArraySet { .. } => self.requires_acir_gen_predicate(dfg),
+        }
+    }
+
     /// Indicates if the instruction can be safely replaced with the results of another instruction with the same inputs.
     /// If `deduplicate_with_predicate` is set, we assume we're deduplicating with the instruction
     /// and its predicate, rather than just the instruction. Setting this means instructions that
     /// rely on predicates can be deduplicated as well.
+    ///
+    /// Some instructions get the predicate attached to their inputs by `handle_instruction_side_effects` in `flatten_cfg`.
+    /// These can be deduplicated because they implicitly depend on the predicate, not only when the caller uses the
+    /// predicate variable as a key to cache results.
+    /// However, to avoid tight coupling between passes, we make the deduplication
+    /// conditional on whether the caller wants the predicate to be taken into account or not.
     pub(crate) fn can_be_deduplicated(
         &self,
         dfg: &DataFlowGraph,
@@ -348,7 +439,9 @@
             | DecrementRc { .. } => false,

             Call { func, .. } => match dfg[*func] {
-                Value::Intrinsic(intrinsic) => !intrinsic.has_side_effects(),
+                Value::Intrinsic(intrinsic) => {
+                    intrinsic.can_be_deduplicated(deduplicate_with_predicate)
+                }
                 _ => false,
             },
@@ -421,6 +514,7 @@
             // Explicitly allows removal of unused ec operations, even if they can fail
             Value::Intrinsic(Intrinsic::BlackBox(BlackBoxFunc::MultiScalarMul))
             | Value::Intrinsic(Intrinsic::BlackBox(BlackBoxFunc::EmbeddedCurveAdd)) => true,
+
             Value::Intrinsic(intrinsic) => !intrinsic.has_side_effects(),

             // All foreign functions are treated as having side effects.
@@ -436,7 +530,7 @@
         }
     }

-    /// If true the instruction will depends on enable_side_effects context during acir-gen
+    /// If true the instruction will depend on `enable_side_effects` context during acir-gen.
     pub(crate) fn requires_acir_gen_predicate(&self, dfg: &DataFlowGraph) -> bool {
         match self {
             Instruction::Binary(binary)
diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs
index 4be37b3c626..46ca7bb8c24 100644
--- a/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs
+++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs
@@ -45,39 +45,47 @@ pub(super) fn simplify_call(
         _ => return SimplifyResult::None,
     };

+    let return_type = ctrl_typevars.and_then(|return_types| return_types.first().cloned());
+
     let constant_args: Option<Vec<_>> =
         arguments.iter().map(|value_id| dfg.get_numeric_constant(*value_id)).collect();

-    match intrinsic {
+    let simplified_result = match intrinsic {
         Intrinsic::ToBits(endian) => {
             // TODO: simplify to a range constraint if `limb_count == 1`
-            if let (Some(constant_args), Some(return_type)) =
-                (constant_args, ctrl_typevars.map(|return_types| return_types.first().cloned()))
-            {
+            if let (Some(constant_args), Some(return_type)) = (constant_args, return_type.clone()) {
                 let field = constant_args[0];
-                let limb_count = if let Some(Type::Array(_, array_len)) = return_type {
-                    array_len as u32
+                let limb_count = if let Type::Array(_, array_len) = return_type {
+                    array_len
                 } else {
                     unreachable!("ICE: Intrinsic::ToRadix return type must be array")
                 };
-                constant_to_radix(endian, field, 2, limb_count, dfg, block, call_stack)
+                constant_to_radix(endian, field, 2, limb_count, |values| {
+                    make_constant_array(dfg, values.into_iter(), Type::bool(), block, call_stack)
+                })
             } else {
                 SimplifyResult::None
             }
         }
         Intrinsic::ToRadix(endian) => {
             // TODO: simplify to a range constraint if `limb_count == 1`
-            if let (Some(constant_args), Some(return_type)) =
-                (constant_args, ctrl_typevars.map(|return_types| return_types.first().cloned()))
-            {
+            if let (Some(constant_args), Some(return_type)) = (constant_args, return_type.clone()) {
                 let field = constant_args[0];
                 let radix = constant_args[1].to_u128() as u32;
-                let limb_count = if let Some(Type::Array(_, array_len)) = return_type {
-                    array_len as u32
+                let limb_count = if let Type::Array(_, array_len) = return_type {
+                    array_len
                 } else {
                     unreachable!("ICE: Intrinsic::ToRadix return type must be array")
                 };
-                constant_to_radix(endian, field, radix, limb_count, dfg, block, call_stack)
+                constant_to_radix(endian, field, radix, limb_count, |values| {
+                    make_constant_array(
+                        dfg,
+                        values.into_iter(),
+                        Type::unsigned(8),
+                        block,
+                        call_stack,
+                    )
+                })
             } else {
                 SimplifyResult::None
             }
@@ -330,7 +338,7 @@ pub(super) fn simplify_call(
         }
         Intrinsic::FromField => {
             let incoming_type = Type::field();
-            let target_type = ctrl_typevars.unwrap().remove(0);
+            let target_type = return_type.clone().unwrap();

             let truncate = Instruction::Truncate {
                 value: arguments[0],
@@ -352,8 +360,8 @@ pub(super) fn simplify_call(
         Intrinsic::AsWitness => SimplifyResult::None,
         Intrinsic::IsUnconstrained => SimplifyResult::None,
         Intrinsic::DerivePedersenGenerators => {
-            if let Some(Type::Array(_, len)) = ctrl_typevars.unwrap().first() {
-                simplify_derive_generators(dfg, arguments, *len as u32, block, call_stack)
+            if let Some(Type::Array(_, len)) = return_type.clone() {
+                simplify_derive_generators(dfg, arguments, len, block, call_stack)
             } else {
                 unreachable!("Derive Pedersen Generators must return an array");
             }
@@ -370,7 +378,19 @@ pub(super) fn simplify_call(
         }
         Intrinsic::ArrayRefCount => SimplifyResult::None,
         Intrinsic::SliceRefCount => SimplifyResult::None,
+    };
+
+    if let (Some(expected_types), SimplifyResult::SimplifiedTo(result)) =
+        (return_type, &simplified_result)
+    {
+        assert_eq!(
+            dfg.type_of_value(*result),
+            expected_types,
+            "Simplification should not alter return type"
+        );
     }
+
+    simplified_result
 }

 /// Slices have a tuple structure (slice length, slice contents) to enable logic
@@ -422,8 +442,8 @@ fn simplify_slice_push_back(
     for elem in &arguments[2..] {
         slice.push_back(*elem);
     }
-    let slice_size = slice.len();
-    let element_size = element_type.element_size();
+    let slice_size = slice.len() as u32;
+    let element_size = element_type.element_size() as u32;
     let new_slice = make_array(dfg, slice, element_type, block, &call_stack);

     let set_last_slice_value_instr = Instruction::ArraySet {
@@ -612,7 +632,7 @@ fn make_constant_array(
     let result_constants: im::Vector<_> =
         results.map(|element| dfg.make_constant(element, typ.clone())).collect();

-    let typ = Type::Array(Arc::new(vec![typ]), result_constants.len());
+    let typ = Type::Array(Arc::new(vec![typ]), result_constants.len() as u32);
     make_array(dfg, result_constants, typ, block, call_stack)
 }

@@ -651,9 +671,7 @@ fn constant_to_radix(
     field: FieldElement,
     radix: u32,
     limb_count: u32,
-    dfg: &mut DataFlowGraph,
-    block: BasicBlockId,
-    call_stack: &CallStack,
+    mut make_array: impl FnMut(Vec<FieldElement>) -> ValueId,
 ) -> SimplifyResult {
     let bit_size = u32::BITS - (radix - 1).leading_zeros();
     let radix_big = BigUint::from(radix);
@@ -674,13 +692,7 @@ fn constant_to_radix(
     if endian == Endian::Big {
         limbs.reverse();
     }
-    let result_array = make_constant_array(
-        dfg,
-        limbs.into_iter(),
-        Type::unsigned(bit_size),
-        block,
-        call_stack,
-    );
+    let result_array = make_array(limbs);
     SimplifyResult::SimplifiedTo(result_array)
 }
@@ -807,7 +819,7 @@ fn simplify_derive_generators(
         results.push(dfg.make_constant(y, Type::field()));
         results.push(is_infinite);
     }
-    let len = results.len();
+    let len = results.len() as u32;
     let typ = Type::Array(vec![Type::field(), Type::field(), Type::unsigned(1)].into(), len / 3);
     let result = make_array(dfg, results.into(), typ, block, call_stack);
diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs
index 4f2a31e2fb0..301b75e0bd4 100644
--- a/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs
+++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs
@@ -48,7 +48,7 @@
 pub(super) fn simplify_ec_add(
     let result_x = dfg.make_constant(result_x, Type::field());
     let result_y = dfg.make_constant(result_y, Type::field());
-    let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool());
+    let result_is_infinity = dfg.make_constant(result_is_infinity, Type::field());

     let typ = Type::Array(Arc::new(vec![Type::field()]), 3);
@@ -107,7 +107,7 @@ pub(super) fn simplify_msm(
     let result_x = dfg.make_constant(result_x, Type::field());
     let result_y = dfg.make_constant(result_y, Type::field());
-    let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool());
+    let result_is_infinity = dfg.make_constant(result_is_infinity, Type::field());

     let elements = im::vector![result_x, result_y, result_is_infinity];
     let typ = Type::Array(Arc::new(vec![Type::field()]), 3);
diff --git a/compiler/noirc_evaluator/src/ssa/ir/types.rs b/compiler/noirc_evaluator/src/ssa/ir/types.rs
index 130f1d59e46..16f4b8d2431 100644
--- a/compiler/noirc_evaluator/src/ssa/ir/types.rs
+++ b/compiler/noirc_evaluator/src/ssa/ir/types.rs
@@ -75,7 +75,7 @@ pub(crate) enum Type {
     Reference(Arc<Type>),

     /// An immutable array value with the given element type and length
-    Array(Arc<CompositeType>, usize),
+    Array(Arc<CompositeType>, u32),

     /// An immutable slice value with a given element type
     Slice(Arc<CompositeType>),
@@ -111,7 +111,7 @@ impl Type {
     }

     /// Creates the str<N> type, of the given length N
-    pub(crate) fn str(length: usize) -> Type {
+    pub(crate) fn str(length: u32) -> Type {
         Type::Array(Arc::new(vec![Type::char()]), length)
     }
@@ -161,7 +161,7 @@ impl Type {
     }

     /// Returns the flattened size of a Type
-    pub(crate) fn flattened_size(&self) -> usize {
+    pub(crate) fn flattened_size(&self) -> u32 {
         match self {
             Type::Array(elements, len) => {
                 elements.iter().fold(0, |sum, elem| sum + (elem.flattened_size() * len))
diff --git a/compiler/noirc_evaluator/src/ssa/opt/as_slice_length.rs b/compiler/noirc_evaluator/src/ssa/opt/as_slice_length.rs
index 76705dcc9db..75cdea349b7 100644
--- a/compiler/noirc_evaluator/src/ssa/opt/as_slice_length.rs
+++ b/compiler/noirc_evaluator/src/ssa/opt/as_slice_length.rs
@@ -33,7 +33,7 @@ impl Function {
     }
 }

-fn known_slice_lengths(func: &Function) -> HashMap<InstructionId, usize> {
+fn known_slice_lengths(func: &Function) -> HashMap<InstructionId, u32> {
     let mut known_slice_lengths = HashMap::default();
     for block_id in func.reachable_blocks() {
         let block = &func.dfg[block_id];
@@ -61,7 +61,7 @@ fn known_slice_lengths(func: &Function) -> HashMap<InstructionId, usize> {

 fn replace_known_slice_lengths(
     func: &mut Function,
-    known_slice_lengths: HashMap<InstructionId, usize>,
+    known_slice_lengths: HashMap<InstructionId, u32>,
 ) {
     known_slice_lengths.into_iter().for_each(|(instruction_id, known_length)| {
         let call_returns = func.dfg.instruction_results(instruction_id);
diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs
index 019bace33a3..d1d84c305d2 100644
--- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs
+++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs
@@ -6,7 +6,7 @@
 //!   by the [`DataFlowGraph`] automatically as new instructions are pushed.
 //! - Check whether any input values have been constrained to be equal to a value of a simpler form
 //!   by a [constrain instruction][Instruction::Constrain]. If so, replace the input value with the simpler form.
-//! - Check whether the instruction [can_be_replaced][Instruction::can_be_replaced()]
+//! - Check whether the instruction [can_be_deduplicated][Instruction::can_be_deduplicated()]
 //!   by duplicate instruction earlier in the same block.
 //!
 //! These operations are done in parallel so that they can each benefit from each other
@@ -54,6 +54,9 @@ use fxhash::FxHashMap as HashMap;
 impl Ssa {
     /// Performs constant folding on each instruction.
     ///
+    /// It will not look at constraints to inform simplifications
+    /// based on the stated equivalence of two instructions.
+    ///
     /// See [`constant_folding`][self] module for more information.
     #[tracing::instrument(level = "trace", skip(self))]
     pub(crate) fn fold_constants(mut self) -> Ssa {
@@ -146,7 +149,8 @@ impl Function {
         use_constraint_info: bool,
         brillig_info: Option<BrilligInfo>,
     ) {
-        let mut context = Context::new(self, use_constraint_info, brillig_info);
+        let mut context = Context::new(use_constraint_info, brillig_info);
+        let mut dom = DominatorTree::with_function(self);
         context.block_queue.push_back(self.entry_block());

         while let Some(block) = context.block_queue.pop_front() {
@@ -155,7 +159,7 @@ impl Function {
             }

             context.visited_blocks.insert(block);
-            context.fold_constants_in_block(self, block);
+            context.fold_constants_in_block(&mut self.dfg, &mut dom, block);
         }
     }
 }
@@ -174,12 +178,10 @@ struct Context<'a> {
     /// We partition the maps of constrained values according to the side-effects flag at the point
     /// at which the values are constrained. This prevents constraints which are only sometimes enforced
     /// being used to modify the rest of the program.
-    constraint_simplification_mappings: HashMap<ValueId, HashMap<ValueId, ValueId>>,
+    constraint_simplification_mappings: ConstraintSimplificationCache,

     // Cache of instructions without any side-effects along with their outputs.
     cached_instruction_results: InstructionResultCache,
-
-    dom: DominatorTree,
 }

 #[derive(Copy, Clone)]
@@ -188,9 +190,56 @@ pub(crate) struct BrilligInfo<'a> {
     brillig_functions: &'a BTreeMap<FunctionId, Function>,
 }

-/// HashMap from (Instruction, side_effects_enabled_var) to the results of the instruction.
+/// Records the simplified equivalents of an [`Instruction`] in the blocks
+/// where the constraint that advised the simplification has been encountered.
+///
+/// For more information see [`ConstraintSimplificationCache`].
+#[derive(Default)]
+struct SimplificationCache {
+    /// Simplified expressions where we found them.
+    ///
+    /// It will always have at least one value because `add` is called
+    /// after the default is constructed.
+    simplifications: HashMap<BasicBlockId, ValueId>,
+}
+
+impl SimplificationCache {
+    /// Called with a newly encountered simplification.
+    fn add(&mut self, dfg: &DataFlowGraph, simple: ValueId, block: BasicBlockId) {
+        self.simplifications
+            .entry(block)
+            .and_modify(|existing| {
+                // `SimplificationCache` may already hold a simplification in this block
+                // so we check whether `simple` is a better simplification than the current one.
+                if let Some((_, simpler)) = simplify(dfg, *existing, simple) {
+                    *existing = simpler;
+                };
+            })
+            .or_insert(simple);
+    }
+
+    /// Try to find a simplification in a visible block.
+    fn get(&self, block: BasicBlockId, dom: &DominatorTree) -> Option<ValueId> {
+        // Deterministically walk up the dominator chain until we encounter a block that contains a simplification.
+        dom.find_map_dominator(block, |b| self.simplifications.get(&b).cloned())
+    }
+}
+
+/// HashMap from `(side_effects_enabled_var, Instruction)` to a simplified expression that it can
+/// be replaced with based on constraints that testify to their equivalence, stored together
+/// with the set of blocks at which this constraint has been observed.
+///
+/// Only blocks dominated by one in the cache should have access to this information, otherwise
+/// we create a sort of time paradox where we replace an instruction with a constant we believe
+/// it _should_ equal, without ever actually producing and asserting the value.
+type ConstraintSimplificationCache = HashMap<ValueId, HashMap<ValueId, SimplificationCache>>;
+
+/// HashMap from `(Instruction, side_effects_enabled_var)` to the results of the instruction.
 /// Stored as a two-level map to avoid cloning Instructions during the `.get` call.
 ///
+/// The `side_effects_enabled_var` is optional because we only use them when `Instruction::requires_acir_gen_predicate`
+/// is true _and_ the constraint information is also taken into account.
+///
 /// In addition to each result, the original BasicBlockId is stored as well. This allows us
 /// to deduplicate instructions across blocks as long as the new block dominates the original.
 type InstructionResultCache = HashMap<Instruction, HashMap<Option<ValueId>, ResultCache>>;
@@ -204,11 +253,7 @@ struct ResultCache {
 }

 impl<'brillig> Context<'brillig> {
-    fn new(
-        function: &Function,
-        use_constraint_info: bool,
-        brillig_info: Option<BrilligInfo<'brillig>>,
-    ) -> Self {
+    fn new(use_constraint_info: bool, brillig_info: Option<BrilligInfo<'brillig>>) -> Self {
         Self {
             use_constraint_info,
             brillig_info,
@@ -216,48 +261,57 @@ impl<'brillig> Context<'brillig> {
             block_queue: Default::default(),
             constraint_simplification_mappings: Default::default(),
             cached_instruction_results: Default::default(),
-            dom: DominatorTree::with_function(function),
         }
     }

-    fn fold_constants_in_block(&mut self, function: &mut Function, block: BasicBlockId) {
-        let instructions = function.dfg[block].take_instructions();
+    fn fold_constants_in_block(
+        &mut self,
+        dfg: &mut DataFlowGraph,
+        dom: &mut DominatorTree,
+        block: BasicBlockId,
+    ) {
+        let instructions = dfg[block].take_instructions();

-        let mut side_effects_enabled_var =
-            function.dfg.make_constant(FieldElement::one(), Type::bool());
+        // Default side effect condition variable with an enabled state.
+        let mut side_effects_enabled_var = dfg.make_constant(FieldElement::one(), Type::bool());

         for instruction_id in instructions {
             self.fold_constants_into_instruction(
-                &mut function.dfg,
+                dfg,
+                dom,
                 block,
                 instruction_id,
                 &mut side_effects_enabled_var,
             );
         }
-        self.block_queue.extend(function.dfg[block].successors());
+        self.block_queue.extend(dfg[block].successors());
     }

     fn fold_constants_into_instruction(
         &mut self,
         dfg: &mut DataFlowGraph,
+        dom: &mut DominatorTree,
         mut block: BasicBlockId,
         id: InstructionId,
         side_effects_enabled_var: &mut ValueId,
     ) {
         let constraint_simplification_mapping = self.get_constraint_map(*side_effects_enabled_var);
-        let instruction = Self::resolve_instruction(id, dfg, constraint_simplification_mapping);
+
+        let instruction =
+            Self::resolve_instruction(id, block, dfg, dom, constraint_simplification_mapping);
+
         let old_results = dfg.instruction_results(id).to_vec();

         // If a copy of this instruction exists earlier in the block, then reuse the previous results.
         if let Some(cache_result) =
-            self.get_cached(dfg, &instruction, *side_effects_enabled_var, block)
+            self.get_cached(dfg, dom, &instruction, *side_effects_enabled_var, block)
         {
             match cache_result {
                 CacheResult::Cached(cached) => {
                     Self::replace_result_ids(dfg, &old_results, cached);
                     return;
                 }
-                CacheResult::NeedToHoistToCommonBlock(dominator, _cached) => {
+                CacheResult::NeedToHoistToCommonBlock(dominator) => {
                     // Just change the block to insert in the common dominator instead.
                     // This will only move the current instance of the instruction right now.
                     // When constant folding is run a second time later on, it'll catch
@@ -265,7 +319,7 @@ impl<'brillig> Context<'brillig> {
                     block = dominator;
                 }
             }
-        }
+        };

         let new_results =
             // First try to inline a call to a brillig function with all constant arguments.
@@ -307,8 +361,10 @@ impl<'brillig> Context<'brillig> {
     /// Fetches an [`Instruction`] by its [`InstructionId`] and fully resolves its inputs.
     fn resolve_instruction(
         instruction_id: InstructionId,
+        block: BasicBlockId,
         dfg: &DataFlowGraph,
-        constraint_simplification_mapping: &HashMap<ValueId, ValueId>,
+        dom: &mut DominatorTree,
+        constraint_simplification_mapping: &HashMap<ValueId, SimplificationCache>,
     ) -> Instruction {
         let instruction = dfg[instruction_id].clone();

@@ -318,20 +374,29 @@ impl<'brillig> Context<'brillig> {
         // This allows us to reach a stable final `ValueId` for each instruction input as we add more
         // constraints to the cache.
         fn resolve_cache(
+            block: BasicBlockId,
             dfg: &DataFlowGraph,
-            cache: &HashMap<ValueId, ValueId>,
+            dom: &mut DominatorTree,
+            cache: &HashMap<ValueId, SimplificationCache>,
             value_id: ValueId,
         ) -> ValueId {
             let resolved_id = dfg.resolve(value_id);
             match cache.get(&resolved_id) {
-                Some(cached_value) => resolve_cache(dfg, cache, *cached_value),
+                Some(simplification_cache) => {
+                    if let Some(simplified) = simplification_cache.get(block, dom) {
+                        resolve_cache(block, dfg, dom, cache, simplified)
+                    } else {
+                        resolved_id
+                    }
+                }
                 None => resolved_id,
             }
         }

         // Resolve any inputs to ensure that we're comparing like-for-like instructions.
-        instruction
-            .map_values(|value_id| resolve_cache(dfg, constraint_simplification_mapping, value_id))
+        instruction.map_values(|value_id| {
+            resolve_cache(block, dfg, dom, constraint_simplification_mapping, value_id)
+        })
     }

     /// Pushes a new [`Instruction`] into the [`DataFlowGraph`] which applies any optimizations
@@ -377,32 +442,43 @@ impl<'brillig> Context<'brillig> {
         // to map from the more complex to the simpler value.
         if let Instruction::Constrain(lhs, rhs, _) = instruction {
             // These `ValueId`s should be fully resolved now.
-            match (&dfg[lhs], &dfg[rhs]) {
-                // Ignore trivial constraints
-                (Value::NumericConstant { .. }, Value::NumericConstant { .. }) => (),
-
-                // Prefer replacing with constants where possible.
-                (Value::NumericConstant { .. }, _) => {
-                    self.get_constraint_map(side_effects_enabled_var).insert(rhs, lhs);
-                }
-                (_, Value::NumericConstant { .. }) => {
-                    self.get_constraint_map(side_effects_enabled_var).insert(lhs, rhs);
-                }
-                // Otherwise prefer block parameters over instruction results.
-                // This is as block parameters are more likely to be a single witness rather than a full expression.
-                (Value::Param { .. }, Value::Instruction { .. }) => {
-                    self.get_constraint_map(side_effects_enabled_var).insert(rhs, lhs);
-                }
-                (Value::Instruction { .. }, Value::Param { .. }) => {
-                    self.get_constraint_map(side_effects_enabled_var).insert(lhs, rhs);
-                }
-                (_, _) => (),
+            if let Some((complex, simple)) = simplify(dfg, lhs, rhs) {
+                self.get_constraint_map(side_effects_enabled_var)
+                    .entry(complex)
+                    .or_default()
+                    .add(dfg, simple, block);
             }
         }
     }

+        // If we have an array get whose value is from an array set on the same array at the same index,
+        // we can simplify that array get to the value of the previous array set.
+        //
+        // For example:
+        // v3 = array_set v0, index v1, value v2
+        // v4 = array_get v3, index v1 -> Field
+        //
+        // We know that `v4` can be simplified to `v2`.
+        // Thus, even if the index is dynamic (meaning the array get would have side effects),
+        // we can simplify the operation when we take into account the predicate.
+        if let Instruction::ArraySet { index, value, .. } = &instruction {
+            let use_predicate =
+                self.use_constraint_info && instruction.requires_acir_gen_predicate(dfg);
+            let predicate = use_predicate.then_some(side_effects_enabled_var);
+
+            let array_get = Instruction::ArrayGet { array: instruction_results[0], index: *index };
+
+            self.cached_instruction_results
+                .entry(array_get)
+                .or_default()
+                .entry(predicate)
+                .or_default()
+                .cache(block, vec![*value]);
+        }
+
         // If the instruction doesn't have side-effects and if it won't interact with enable_side_effects during acir_gen,
         // we cache the results so we can reuse them if the same instruction appears again later in the block.
+        // Others have side effects representing failure, which are implicit in the ACIR code and can also be deduplicated.
         if instruction.can_be_deduplicated(dfg, self.use_constraint_info) {
             let use_predicate =
                 self.use_constraint_info && instruction.requires_acir_gen_predicate(dfg);
@@ -417,10 +493,12 @@ impl<'brillig> Context<'brillig> {
         }
     }

+    /// Get the simplification mapping from complex to simpler instructions,
+    /// which all depend on the same side effect condition variable.
     fn get_constraint_map(
         &mut self,
         side_effects_enabled_var: ValueId,
-    ) -> &mut HashMap<ValueId, ValueId> {
+    ) -> &mut HashMap<ValueId, SimplificationCache> {
         self.constraint_simplification_mappings.entry(side_effects_enabled_var).or_default()
     }

@@ -435,19 +513,20 @@ impl<'brillig> Context<'brillig> {
         }
     }

+    /// Get a cached result if it can be used in this context.
     fn get_cached(
-        &mut self,
+        &self,
         dfg: &DataFlowGraph,
+        dom: &mut DominatorTree,
         instruction: &Instruction,
         side_effects_enabled_var: ValueId,
         block: BasicBlockId,
     ) -> Option<CacheResult> {
         let results_for_instruction = self.cached_instruction_results.get(instruction)?;
-
         let predicate = self.use_constraint_info && instruction.requires_acir_gen_predicate(dfg);
         let predicate = predicate.then_some(side_effects_enabled_var);

-        results_for_instruction.get(&predicate)?.get(block, &mut self.dom)
+        results_for_instruction.get(&predicate)?.get(block, dom, instruction.has_side_effects(dfg))
     }

     /// Checks if the given instruction is a call to a brillig function with all constant arguments.
@@ -625,14 +704,21 @@ impl ResultCache {
     /// We require that the cached instruction's block dominates `block` in order to avoid
     /// cycles causing issues (e.g. two instructions being replaced with the results of each other
     /// such that neither instruction exists anymore.)
-    fn get(&self, block: BasicBlockId, dom: &mut DominatorTree) -> Option<CacheResult> {
-        self.result.as_ref().map(|(origin_block, results)| {
+    fn get(
+        &self,
+        block: BasicBlockId,
+        dom: &mut DominatorTree,
+        has_side_effects: bool,
+    ) -> Option<CacheResult> {
+        self.result.as_ref().and_then(|(origin_block, results)| {
             if dom.dominates(*origin_block, block) {
-                CacheResult::Cached(results)
-            } else {
+                Some(CacheResult::Cached(results))
+            } else if !has_side_effects {
                 // Insert a copy of this instruction in the common dominator
                 let dominator = dom.common_dominator(*origin_block, block);
-                CacheResult::NeedToHoistToCommonBlock(dominator, results)
+                Some(CacheResult::NeedToHoistToCommonBlock(dominator))
+            } else {
+                None
             }
         })
     }
@@ -640,7 +726,7 @@ impl ResultCache {

 enum CacheResult<'a> {
     Cached(&'a [ValueId]),
-    NeedToHoistToCommonBlock(BasicBlockId, &'a [ValueId]),
+    NeedToHoistToCommonBlock(BasicBlockId),
 }

 /// Result of trying to evaluate an instruction (any instruction) in this pass.
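The new `ResultCache::get` above encodes a three-way decision that is easy to miss in the diff. As a rough standalone sketch (placeholder `Block` type, and closures standing in for the real `DominatorTree` queries; this is not the compiler's actual API):

```rust
#[derive(Clone, Copy)]
struct Block(u32);

enum Decision {
    Reuse,        // the cached results are valid at the current block
    Hoist(Block), // re-insert the instruction in the common dominator
    Miss,         // side effects forbid hoisting; treat as a cache miss
}

fn decide(
    origin: Block,
    current: Block,
    has_side_effects: bool,
    dominates: impl Fn(Block, Block) -> bool,
    common_dominator: impl Fn(Block, Block) -> Block,
) -> Decision {
    if dominates(origin, current) {
        // The earlier instruction runs on every path to `current`, so reuse it.
        Decision::Reuse
    } else if !has_side_effects {
        // Safe to recompute once in a block that reaches both occurrences.
        Decision::Hoist(common_dominator(origin, current))
    } else {
        // A side-effecting instruction (e.g. one that can fail) must not be
        // hoisted past a branch where its predicate may differ.
        Decision::Miss
    }
}
```

This is exactly the behavior the `does_not_hoist_constrain_to_common_ancestor` and `does_not_hoist_sub_to_common_ancestor` tests later in this diff pin down.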
@@ -665,7 +751,7 @@ pub(crate) fn type_to_brillig_parameter(typ: &Type) -> Option for item_typ in item_type.iter() { parameters.push(type_to_brillig_parameter(item_typ)?); } - Some(BrilligParameter::Array(parameters, *size)) + Some(BrilligParameter::Array(parameters, *size as usize)) } _ => None, } @@ -687,6 +773,25 @@ fn value_id_to_calldata(value_id: ValueId, dfg: &DataFlowGraph, calldata: &mut V panic!("Expected ValueId to be numeric constant or array constant"); } +/// Check if one expression is simpler than the other. +/// Returns `Some((complex, simple))` if a simplification was found, otherwise `None`. +/// Expects the `ValueId`s to be fully resolved. +fn simplify(dfg: &DataFlowGraph, lhs: ValueId, rhs: ValueId) -> Option<(ValueId, ValueId)> { + match (&dfg[lhs], &dfg[rhs]) { + // Ignore trivial constraints + (Value::NumericConstant { .. }, Value::NumericConstant { .. }) => None, + + // Prefer replacing with constants where possible. + (Value::NumericConstant { .. }, _) => Some((rhs, lhs)), + (_, Value::NumericConstant { .. }) => Some((lhs, rhs)), + // Otherwise prefer block parameters over instruction results. + // This is as block parameters are more likely to be a single witness rather than a full expression. + (Value::Param { .. }, Value::Instruction { .. }) => Some((rhs, lhs)), + (Value::Instruction { .. }, Value::Param { .. }) => Some((lhs, rhs)), + (_, _) => None, + } +} + #[cfg(test)] mod test { use std::sync::Arc; @@ -923,32 +1028,22 @@ mod test { // Regression for #4600 #[test] fn array_get_regression() { - // fn main f0 { - // b0(v0: u1, v1: u64): - // enable_side_effects_if v0 - // v2 = make_array [Field 0, Field 1] - // v3 = array_get v2, index v1 - // v4 = not v0 - // enable_side_effects_if v4 - // v5 = array_get v2, index v1 - // } - // // We want to make sure after constant folding both array_gets remain since they are // under different enable_side_effects_if contexts and thus one may be disabled while // the other is not. If one is removed, it is possible e.g. v4 is replaced with v2 which // is disabled (only gets from index 0) and thus returns the wrong result. let src = " - acir(inline) fn main f0 { - b0(v0: u1, v1: u64): - enable_side_effects v0 - v4 = make_array [Field 0, Field 1] : [Field; 2] - v5 = array_get v4, index v1 -> Field - v6 = not v0 - enable_side_effects v6 - v7 = array_get v4, index v1 -> Field - return - } - "; + acir(inline) fn main f0 { + b0(v0: u1, v1: u64): + enable_side_effects v0 + v4 = make_array [Field 0, Field 1] : [Field; 2] + v5 = array_get v4, index v1 -> Field + v6 = not v0 + enable_side_effects v6 + v7 = array_get v4, index v1 -> Field + return + } + "; let ssa = Ssa::from_str(src).unwrap(); // Expected output is unchanged @@ -1015,7 +1110,6 @@ mod test { // v5 = call keccakf1600(v1) // v6 = call keccakf1600(v2) // } - // // Here we're checking a situation where two identical arrays are being initialized twice and being assigned separate `ValueId`s. // This would result in otherwise identical instructions not being deduplicated. 
let main_id = Id::test_new(0); @@ -1118,7 +1212,7 @@ mod test { v2 = lt u32 1000, v0 jmpif v2 then: b1, else: b2 b1(): - v4 = add v0, u32 1 + v4 = shl v0, u32 1 v5 = lt v0, v4 constrain v5 == u1 1 jmp b2() @@ -1126,7 +1220,7 @@ mod test { v7 = lt u32 1000, v0 jmpif v7 then: b3, else: b4 b3(): - v8 = add v0, u32 1 + v8 = shl v0, u32 1 v9 = lt v0, v8 constrain v9 == u1 1 jmp b4() @@ -1144,10 +1238,10 @@ mod test { brillig(inline) fn main f0 { b0(v0: u32): v2 = lt u32 1000, v0 - v4 = add v0, u32 1 + v4 = shl v0, u32 1 jmpif v2 then: b1, else: b2 b1(): - v5 = add v0, u32 1 + v5 = shl v0, u32 1 v6 = lt v0, v5 constrain v6 == u1 1 jmp b2() @@ -1341,4 +1435,160 @@ mod test { let ssa = ssa.fold_constants_with_brillig(&brillig); assert_normalized_ssa_equals(ssa, expected); } + + #[test] + fn does_not_use_cached_constrain_in_block_that_is_not_dominated() { + let src = " + brillig(inline) fn main f0 { + b0(v0: Field, v1: Field): + v3 = eq v0, Field 0 + jmpif v3 then: b1, else: b2 + b1(): + v5 = eq v1, Field 1 + constrain v1 == Field 1 + jmp b2() + b2(): + v6 = eq v1, Field 0 + constrain v1 == Field 0 + return + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.fold_constants_using_constraints(); + assert_normalized_ssa_equals(ssa, src); + } + + #[test] + fn does_not_hoist_constrain_to_common_ancestor() { + let src = " + brillig(inline) fn main f0 { + b0(v0: Field, v1: Field): + v3 = eq v0, Field 0 + jmpif v3 then: b1, else: b2 + b1(): + constrain v1 == Field 1 + jmp b2() + b2(): + jmpif v0 then: b3, else: b4 + b3(): + constrain v1 == Field 1 // This was incorrectly hoisted to b0 but this condition is not valid when going b0 -> b2 -> b4 + jmp b4() + b4(): + return + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.fold_constants_using_constraints(); + assert_normalized_ssa_equals(ssa, src); + } + + #[test] + fn does_not_hoist_sub_to_common_ancestor() { + let src = " + acir(inline) fn main f0 { + b0(v0: u32): + v2 = eq v0, u32 0 + jmpif v2 then: b4, else: b1 + b4(): + v5 = sub v0, u32 1 + jmp b5() + b5(): + return + b1(): + jmpif v0 then: b3, else: b2 + b3(): + v4 = sub v0, u32 1 // We can't hoist this because v0 is zero here and it will lead to an underflow + jmp b5() + b2(): + jmp b5() + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.fold_constants_using_constraints(); + assert_normalized_ssa_equals(ssa, src); + } + + #[test] + fn deduplicates_side_effecting_intrinsics() { + let src = " + // After EnableSideEffectsIf removal: + acir(inline) fn main f0 { + b0(v0: Field, v1: Field, v2: u1): + v4 = call is_unconstrained() -> u1 + v7 = call to_be_radix(v0, u32 256) -> [u8; 1] // `a.to_be_radix(256)`; + inc_rc v7 + v8 = call to_be_radix(v0, u32 256) -> [u8; 1] // duplicate load of `a` + inc_rc v8 + v9 = cast v2 as Field // `if c { a.to_be_radix(256) }` + v10 = mul v0, v9 // attaching `c` to `a` + v11 = call to_be_radix(v10, u32 256) -> [u8; 1] // calling `to_radix(c * a)` + inc_rc v11 + enable_side_effects v2 // side effect var for `c` shifted down by removal + return + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let expected = " + acir(inline) fn main f0 { + b0(v0: Field, v1: Field, v2: u1): + v4 = call is_unconstrained() -> u1 + v7 = call to_be_radix(v0, u32 256) -> [u8; 1] + inc_rc v7 + inc_rc v7 + v8 = cast v2 as Field + v9 = mul v0, v8 + v10 = call to_be_radix(v9, u32 256) -> [u8; 1] + inc_rc v10 + enable_side_effects v2 + return + } + "; + let ssa = ssa.fold_constants_using_constraints(); + assert_normalized_ssa_equals(ssa, expected); + } + + 
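The test above and the two that follow all exercise the two-level result cache: results are keyed first by the instruction and then by an optional side-effects predicate. A simplified sketch of that shape (string-keyed and `u32`-valued here for brevity; the real pass keys on `Instruction` and `Option<ValueId>`):

```rust
use std::collections::HashMap;

type ValueId = u32;
type Results = Vec<ValueId>;

#[derive(Default)]
struct Cache {
    // instruction -> (optional predicate -> cached results)
    results: HashMap<String, HashMap<Option<ValueId>, Results>>,
}

impl Cache {
    fn insert(&mut self, instruction: String, predicate: Option<ValueId>, results: Results) {
        self.results.entry(instruction).or_default().insert(predicate, results);
    }

    // An instruction cached under `Some(predicate)` is only reused while the
    // same `enable_side_effects` variable is active; one cached under `None`
    // is predicate-independent.
    fn get(&self, instruction: &str, predicate: Option<ValueId>) -> Option<&Results> {
        self.results.get(instruction)?.get(&predicate)
    }
}
```

This is why the `array_get_from_array_set_with_different_predicates` test below expects no change, while the same-predicate variant folds the `array_get` away.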
+    #[test]
+    fn array_get_from_array_set_with_different_predicates() {
+        let src = "
+            acir(inline) fn main f0 {
+              b0(v0: [Field; 3], v1: u32, v2: Field):
+                enable_side_effects u1 0
+                v4 = array_set v0, index v1, value v2
+                enable_side_effects u1 1
+                v6 = array_get v4, index v1 -> Field
+                return v6
+            }
+            ";
+
+        let ssa = Ssa::from_str(src).unwrap();
+
+        let ssa = ssa.fold_constants_using_constraints();
+        // We expect the code to be unchanged
+        assert_normalized_ssa_equals(ssa, src);
+    }
+
+    #[test]
+    fn array_get_from_array_set_same_predicates() {
+        let src = "
+            acir(inline) fn main f0 {
+              b0(v0: [Field; 3], v1: u32, v2: Field):
+                enable_side_effects u1 1
+                v4 = array_set v0, index v1, value v2
+                v6 = array_get v4, index v1 -> Field
+                return v6
+            }
+            ";
+        let ssa = Ssa::from_str(src).unwrap();
+
+        let expected = "
+            acir(inline) fn main f0 {
+              b0(v0: [Field; 3], v1: u32, v2: Field):
+                enable_side_effects u1 1
+                v4 = array_set v0, index v1, value v2
+                return v2
+            }
+            ";
+        let ssa = ssa.fold_constants_using_constraints();
+        assert_normalized_ssa_equals(ssa, expected);
+    }
 }
diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs
index 61a93aee58d..c8dd0e3c5a3 100644
--- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs
+++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs
@@ -131,7 +131,7 @@
 //!   v11 = mul v4, Field 12
 //!   v12 = add v10, v11
 //!   store v12 at v5          (new store)
-use fxhash::FxHashMap as HashMap;
+use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};

 use acvm::{acir::AcirField, acir::BlackBoxFunc, FieldElement};
 use iter_extended::vecmap;
@@ -201,6 +201,15 @@ struct Context<'f> {
     /// When processing a block, we pop this stack to get its arguments
     /// and at the end we push the arguments for its successor
     arguments_stack: Vec<Vec<ValueId>>,
+
+    /// Stores all allocations local to the current branch.
+    ///
+    /// Since these allocations are local to the current branch (i.e. only defined within one branch of
+    /// an if expression), they should not be merged with their previous value or stored value in
+    /// the other branch since there is no such value.
+    ///
+    /// The `ValueId` here is that which is returned by the allocate instruction.
+    local_allocations: HashSet<ValueId>,
 }

 #[derive(Clone)]
@@ -211,6 +220,8 @@ struct ConditionalBranch {
     old_condition: ValueId,
     // The condition of the branch
     condition: ValueId,
+    // The allocations accumulated when processing the branch
+    local_allocations: HashSet<ValueId>,
 }

 struct ConditionalContext {
@@ -243,6 +254,7 @@ fn flatten_function_cfg(function: &mut Function, no_predicates: &HashMap<FunctionId, bool>) {
 impl<'f> Context<'f> {
let instructions = self.inserter.function.dfg[block].instructions().to_vec(); - let mut previous_allocate_result = None; - for instruction in instructions.iter() { if self.is_no_predicate(no_predicates, instruction) { // disable side effect for no_predicate functions @@ -332,10 +342,10 @@ impl<'f> Context<'f> { None, im::Vector::new(), ); - self.push_instruction(*instruction, &mut previous_allocate_result); + self.push_instruction(*instruction); self.insert_current_side_effects_enabled(); } else { - self.push_instruction(*instruction, &mut previous_allocate_result); + self.push_instruction(*instruction); } } } @@ -405,10 +415,12 @@ impl<'f> Context<'f> { let old_condition = *condition; let then_condition = self.inserter.resolve(old_condition); + let old_allocations = std::mem::take(&mut self.local_allocations); let branch = ConditionalBranch { old_condition, condition: self.link_condition(then_condition), last_block: *then_destination, + local_allocations: old_allocations, }; let cond_context = ConditionalContext { condition: then_condition, @@ -419,6 +431,16 @@ impl<'f> Context<'f> { }; self.condition_stack.push(cond_context); self.insert_current_side_effects_enabled(); + + // We disallow this case as it results in the `else_destination` block + // being inlined before the `then_destination` block due to block deduplication in the work queue. + // + // The `else_destination` block then gets treated as if it were the `then_destination` block + // and has the incorrect condition applied to it. + assert_ne!( + self.branch_ends[if_entry], *then_destination, + "ICE: branches merge inside of `then` branch" + ); vec![self.branch_ends[if_entry], *else_destination, *then_destination] } @@ -435,11 +457,14 @@ impl<'f> Context<'f> { ); let else_condition = self.link_condition(else_condition); + let old_allocations = std::mem::take(&mut self.local_allocations); let else_branch = ConditionalBranch { old_condition: cond_context.then_branch.old_condition, condition: else_condition, last_block: *block, + local_allocations: old_allocations, }; + cond_context.then_branch.local_allocations.clear(); cond_context.else_branch = Some(else_branch); self.condition_stack.push(cond_context); @@ -461,6 +486,7 @@ impl<'f> Context<'f> { } let mut else_branch = cond_context.else_branch.unwrap(); + self.local_allocations = std::mem::take(&mut else_branch.local_allocations); else_branch.last_block = *block; cond_context.else_branch = Some(else_branch); @@ -593,35 +619,27 @@ impl<'f> Context<'f> { /// `previous_allocate_result` should only be set to the result of an allocate instruction /// if that instruction was the instruction immediately previous to this one - if there are /// any instructions in between it should be None. 
-    fn push_instruction(
-        &mut self,
-        id: InstructionId,
-        previous_allocate_result: &mut Option<ValueId>,
-    ) {
+    fn push_instruction(&mut self, id: InstructionId) {
         let (instruction, call_stack) = self.inserter.map_instruction(id);
-        let instruction = self.handle_instruction_side_effects(
-            instruction,
-            call_stack.clone(),
-            *previous_allocate_result,
-        );
+        let instruction = self.handle_instruction_side_effects(instruction, call_stack.clone());

         let instruction_is_allocate = matches!(&instruction, Instruction::Allocate);
         let entry = self.inserter.function.entry_block();
         let results = self.inserter.push_instruction_value(instruction, id, entry, call_stack);
-        *previous_allocate_result = instruction_is_allocate.then(|| results.first());
+
+        // Remember an allocate was created local to this branch so that we do not try to merge store
+        // values across branches for it later.
+        if instruction_is_allocate {
+            self.local_allocations.insert(results.first());
+        }
     }

     /// If we are currently in a branch, we need to modify constrain instructions
     /// to multiply them by the branch's condition (see optimization #1 in the module comment).
-    ///
-    /// `previous_allocate_result` should only be set to the result of an allocate instruction
-    /// if that instruction was the instruction immediately previous to this one - if there are
-    /// any instructions in between it should be None.
     fn handle_instruction_side_effects(
         &mut self,
         instruction: Instruction,
         call_stack: CallStack,
-        previous_allocate_result: Option<ValueId>,
     ) -> Instruction {
         if let Some(condition) = self.get_last_condition() {
             match instruction {
@@ -652,7 +670,7 @@ impl<'f> Context<'f> {
                 Instruction::Store { address, value } => {
                     // If this instruction immediately follows an allocate, and stores to that
                     // address there is no previous value to load and we don't need a merge anyway.
-                    if Some(address) == previous_allocate_result {
+                    if self.local_allocations.contains(&address) {
                         Instruction::Store { address, value }
                     } else {
                         // Instead of storing `value`, store `if condition { value } else { previous_value }`
@@ -1259,6 +1277,7 @@ mod test {
     fn should_not_merge_incorrectly_to_false() {
         // Regression test for #1792
         // Tests that it does not simplify a true constraint into an always-false constraint
+
         let src = "
             acir(inline) fn main f0 {
               b0(v0: [u8; 2]):
@@ -1463,4 +1482,23 @@ mod test {
             _ => unreachable!("Should have terminator instruction"),
         }
     }
+
+    #[test]
+    #[should_panic = "ICE: branches merge inside of `then` branch"]
+    fn panics_if_branches_merge_within_then_branch() {
+        //! This is a regression test for https://github.com/noir-lang/noir/issues/6620
+
+        let src = "
+            acir(inline) fn main f0 {
+              b0(v0: u1):
+                jmpif v0 then: b2, else: b1
+              b2():
+                return
+              b1():
+                jmp b2()
+            }
+            ";
+        let merged_ssa = Ssa::from_str(src).unwrap();
+        let _ = merged_ssa.flatten_cfg();
+    }
 }
diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/capacity_tracker.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/capacity_tracker.rs
index ddc8b0bfe6b..a01be691778 100644
--- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/capacity_tracker.rs
+++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/capacity_tracker.rs
@@ -21,7 +21,7 @@ impl<'a> SliceCapacityTracker<'a> {
     pub(crate) fn collect_slice_information(
         &self,
         instruction: &Instruction,
-        slice_sizes: &mut HashMap<ValueId, usize>,
+        slice_sizes: &mut HashMap<ValueId, u32>,
         results: &[ValueId],
     ) {
         match instruction {
@@ -106,13 +106,12 @@ impl<'a> SliceCapacityTracker<'a> {
                         Intrinsic::ToBits(_) => {
                             // Compiler sanity check
                             assert!(matches!(self.dfg.type_of_value(result_slice), Type::Slice(_)));
-                            slice_sizes.insert(result_slice, FieldElement::max_num_bits() as usize);
+                            slice_sizes.insert(result_slice, FieldElement::max_num_bits());
                         }
                         Intrinsic::ToRadix(_) => {
                             // Compiler sanity check
                             assert!(matches!(self.dfg.type_of_value(result_slice), Type::Slice(_)));
-                            slice_sizes
-                                .insert(result_slice, FieldElement::max_num_bytes() as usize);
+                            slice_sizes.insert(result_slice, FieldElement::max_num_bytes());
                         }
                         Intrinsic::AsSlice => {
                             let array_size = self
@@ -157,7 +156,7 @@ impl<'a> SliceCapacityTracker<'a> {
     pub(crate) fn compute_slice_capacity(
         &self,
         array_id: ValueId,
-        slice_sizes: &mut HashMap<ValueId, usize>,
+        slice_sizes: &mut HashMap<ValueId, u32>,
     ) {
         if let Some((array, typ)) = self.dfg.get_array_constant(array_id) {
             // Compiler sanity check
@@ -165,7 +164,7 @@ impl<'a> SliceCapacityTracker<'a> {
             if let Type::Slice(_) = typ {
                 let element_size = typ.element_size();
                 let len = array.len() / element_size;
-                slice_sizes.insert(array_id, len);
+                slice_sizes.insert(array_id, len as u32);
             }
         }
     }
diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs
index 8ea26d4e96d..c97572251db 100644
--- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs
+++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs
@@ -17,7 +17,7 @@ pub(crate) struct ValueMerger<'a> {
     // Maps SSA array values with a slice type to their size.
     // This must be computed before merging values.
-    slice_sizes: &'a mut HashMap<ValueId, usize>,
+    slice_sizes: &'a mut HashMap<ValueId, u32>,

     array_set_conditionals: &'a mut HashMap<ValueId, ValueId>,
@@ -28,7 +28,7 @@ impl<'a> ValueMerger<'a> {
     pub(crate) fn new(
         dfg: &'a mut DataFlowGraph,
         block: BasicBlockId,
-        slice_sizes: &'a mut HashMap<ValueId, usize>,
+        slice_sizes: &'a mut HashMap<ValueId, u32>,
         array_set_conditionals: &'a mut HashMap<ValueId, ValueId>,
         current_condition: Option<ValueId>,
         call_stack: CallStack,
@@ -162,7 +162,7 @@ impl<'a> ValueMerger<'a> {
             _ => panic!("Expected array type"),
         };

-        let actual_length = len * element_types.len();
+        let actual_length = len * element_types.len() as u32;

         if let Some(result) = self.try_merge_only_changed_indices(
             then_condition,
@@ -175,7 +175,8 @@ impl<'a> ValueMerger<'a> {

         for i in 0..len {
             for (element_index, element_type) in element_types.iter().enumerate() {
-                let index = ((i * element_types.len() + element_index) as u128).into();
+                let index =
+                    ((i * element_types.len() as u32 + element_index as u32) as u128).into();
                 let index = self.dfg.make_constant(index, Type::field());

                 let typevars = Some(vec![element_type.clone()]);
@@ -222,22 +223,22 @@ impl<'a> ValueMerger<'a> {
             let (slice, typ) = self.dfg.get_array_constant(then_value_id).unwrap_or_else(|| {
                 panic!("ICE: Merging values during flattening encountered slice {then_value_id} without a preset size");
             });
-            slice.len() / typ.element_types().len()
+            (slice.len() / typ.element_types().len()) as u32
         });

         let else_len = self.slice_sizes.get(&else_value_id).copied().unwrap_or_else(|| {
             let (slice, typ) = self.dfg.get_array_constant(else_value_id).unwrap_or_else(|| {
                 panic!("ICE: Merging values during flattening encountered slice {else_value_id} without a preset size");
             });
-            slice.len() / typ.element_types().len()
+            (slice.len() / typ.element_types().len()) as u32
         });

         let len = then_len.max(else_len);

         for i in 0..len {
             for (element_index, element_type) in element_types.iter().enumerate() {
-                let index_usize = i * element_types.len() + element_index;
-                let index_value = (index_usize as u128).into();
+                let index_u32 = i * element_types.len() as u32 + element_index as u32;
+                let index_value = (index_u32 as u128).into();
                 let index = self.dfg.make_constant(index_value, Type::field());

                 let typevars = Some(vec![element_type.clone()]);
@@ -245,7 +246,7 @@ impl<'a> ValueMerger<'a> {
                 let mut get_element = |array, typevars, len| {
                     // The smaller slice is filled with placeholder data. Codegen for slice accesses must
                     // include checks against the dynamic slice length so that this placeholder data is not incorrectly accessed.
- if len <= index_usize { + if len <= index_u32 { self.make_slice_dummy_data(element_type) } else { let get = Instruction::ArrayGet { array, index }; @@ -260,10 +261,13 @@ impl<'a> ValueMerger<'a> { } }; - let then_element = - get_element(then_value_id, typevars.clone(), then_len * element_types.len()); + let then_element = get_element( + then_value_id, + typevars.clone(), + then_len * element_types.len() as u32, + ); let else_element = - get_element(else_value_id, typevars, else_len * element_types.len()); + get_element(else_value_id, typevars, else_len * element_types.len() as u32); merged.push_back(self.merge_values(then_condition, then_element, else_element)); } @@ -316,7 +320,7 @@ impl<'a> ValueMerger<'a> { then_condition: ValueId, then_value: ValueId, else_value: ValueId, - array_length: usize, + array_length: u32, ) -> Option { let mut found = false; let current_condition = self.current_condition?; @@ -370,7 +374,7 @@ impl<'a> ValueMerger<'a> { .chain(seen_else.into_iter().map(|(_, index, typ, condition)| (index, typ, condition))) .collect(); - if !found || changed_indices.len() >= array_length { + if !found || changed_indices.len() as u32 >= array_length { return None; } diff --git a/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs b/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs index 14233ca73e5..2e5c44cce76 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs @@ -7,14 +7,16 @@ //! - Already marked as loop invariants //! //! We also check that we are not hoisting instructions with side effects. -use fxhash::FxHashSet as HashSet; +use acvm::{acir::AcirField, FieldElement}; +use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; use crate::ssa::{ ir::{ basic_block::BasicBlockId, function::{Function, RuntimeType}, function_inserter::FunctionInserter, - instruction::InstructionId, + instruction::{Instruction, InstructionId}, + types::Type, value::ValueId, }, Ssa, @@ -45,25 +47,51 @@ impl Function { } impl Loops { - fn hoist_loop_invariants(self, function: &mut Function) { + fn hoist_loop_invariants(mut self, function: &mut Function) { let mut context = LoopInvariantContext::new(function); - for loop_ in self.yet_to_unroll.iter() { + // The loops should be sorted by the number of blocks. + // We want to access outer nested loops first, which we do by popping + // from the top of the list. + while let Some(loop_) = self.yet_to_unroll.pop() { let Ok(pre_header) = loop_.get_pre_header(context.inserter.function, &self.cfg) else { // If the loop does not have a preheader we skip hoisting loop invariants for this loop continue; }; - context.hoist_loop_invariants(loop_, pre_header); + + context.hoist_loop_invariants(&loop_, pre_header); } context.map_dependent_instructions(); } } +impl Loop { + /// Find the value that controls whether to perform a loop iteration. + /// This is going to be the block parameter of the loop header. + /// + /// Consider the following example of a `for i in 0..4` loop: + /// ```text + /// brillig(inline) fn main f0 { + /// b0(v0: u32): + /// ... 
+    ///     jmp b1(u32 0)
+    ///   b1(v1: u32):          // Loop header
+    ///     v5 = lt v1, u32 4   // Upper bound
+    ///     jmpif v5 then: b3, else: b2
+    /// ```
+    /// In the example above, `v1` is the induction variable.
+    fn get_induction_variable(&self, function: &Function) -> ValueId {
+        function.dfg.block_parameters(self.header)[0]
+    }
+}
+
 struct LoopInvariantContext<'f> {
     inserter: FunctionInserter<'f>,
     defined_in_loop: HashSet<ValueId>,
     loop_invariants: HashSet<InstructionId>,
+    // Maps induction variable -> fixed upper loop bound
+    outer_induction_variables: HashMap<ValueId, FieldElement>,
 }

 impl<'f> LoopInvariantContext<'f> {
@@ -72,6 +100,7 @@ impl<'f> LoopInvariantContext<'f> {
             inserter: FunctionInserter::new(function),
             defined_in_loop: HashSet::default(),
             loop_invariants: HashSet::default(),
+            outer_induction_variables: HashMap::default(),
         }
     }

@@ -88,13 +117,29 @@ impl<'f> LoopInvariantContext<'f> {
                 self.inserter.push_instruction(instruction_id, *block);
             }

-            self.update_values_defined_in_loop_and_invariants(instruction_id, hoist_invariant);
+            self.extend_values_defined_in_loop_and_invariants(instruction_id, hoist_invariant);
         }
+
+        // Keep track of a loop induction variable and its respective upper bound.
+        // This will be used by later loops to determine whether they have operations
+        // reliant upon the maximum induction variable.
+        let upper_bound = loop_.get_const_upper_bound(self.inserter.function);
+        if let Some(upper_bound) = upper_bound {
+            let induction_variable = loop_.get_induction_variable(self.inserter.function);
+            let induction_variable = self.inserter.resolve(induction_variable);
+            self.outer_induction_variables.insert(induction_variable, upper_bound);
+        }
     }

     /// Gather the variables declared within the loop
     fn set_values_defined_in_loop(&mut self, loop_: &Loop) {
+        // Clear any values that may be defined in previous loops, as the context is per function.
+        self.defined_in_loop.clear();
+        // These are safe to keep per function, but we want to be clear that these values
+        // are used per loop.
+        self.loop_invariants.clear();
+
         for block in loop_.blocks.iter() {
             let params = self.inserter.function.dfg.block_parameters(*block);
             self.defined_in_loop.extend(params);
@@ -107,7 +152,7 @@ impl<'f> LoopInvariantContext<'f> {

     /// Update any values defined in the loop and loop invariants after
     /// analyzing and re-inserting a loop's instruction.
-    fn update_values_defined_in_loop_and_invariants(
+    fn extend_values_defined_in_loop_and_invariants(
         &mut self,
         instruction_id: InstructionId,
         hoist_invariant: bool,
@@ -143,9 +188,45 @@ impl<'f> LoopInvariantContext<'f> {
             is_loop_invariant &=
                 !self.defined_in_loop.contains(&value) || self.loop_invariants.contains(&value);
         });
-        is_loop_invariant && instruction.can_be_deduplicated(&self.inserter.function.dfg, false)
+
+        let can_be_deduplicated = instruction
+            .can_be_deduplicated(&self.inserter.function.dfg, false)
+            || self.can_be_deduplicated_from_upper_bound(&instruction);
+
+        is_loop_invariant && can_be_deduplicated
+    }
+
+    /// Certain instructions can take advantage of the fact that our induction variable has a fixed maximum.
+    ///
+    /// For example, an array access can usually only be safely deduplicated when we have a constant
+    /// index that is below the length of the array.
+    /// Checking an array get where the index is the loop's induction variable on its own
+    /// would determine that the instruction is not safe for hoisting.
+    /// However, if we know that the induction variable's upper bound will always be in bounds of the array
+    /// we can safely hoist the array access.
+ fn can_be_deduplicated_from_upper_bound(&self, instruction: &Instruction) -> bool { + match instruction { + Instruction::ArrayGet { array, index } => { + let array_typ = self.inserter.function.dfg.type_of_value(*array); + let upper_bound = self.outer_induction_variables.get(index); + if let (Type::Array(_, len), Some(upper_bound)) = (array_typ, upper_bound) { + upper_bound.to_u128() <= len.into() + } else { + false + } + } + _ => false, + } } + /// Loop invariant hoisting only operates over loop instructions. + /// The `FunctionInserter` is used for mapping old values to new values after + /// re-inserting loop invariant instructions. + /// However, there may be instructions which are not within loops that are + /// still reliant upon the instruction results altered during the pass. + /// This method re-inserts all instructions so that all instructions have + /// correct new value IDs based upon the `FunctionInserter` internal map. + /// Leaving out this mapping could lead to instructions with values that do not exist. fn map_dependent_instructions(&mut self) { let blocks = self.inserter.function.reachable_blocks(); for block in blocks { @@ -375,4 +456,108 @@ mod test { // The code should be unchanged assert_normalized_ssa_equals(ssa, src); } + + #[test] + fn hoist_array_gets_using_induction_variable_with_const_bound() { + // SSA for the following program: + // + // fn triple_loop(x: u32) { + // let arr = [2; 5]; + // for i in 0..4 { + // for j in 0..4 { + // for _ in 0..4 { + // assert_eq(arr[i], x); + // assert_eq(arr[j], x); + // } + // } + // } + // } + // + // `arr[i]` and `arr[j]` are safe to hoist as we know the maximum possible index + // to be used for both array accesses. + // We want to make sure `arr[i]` is hoisted to the outermost loop body and that + // `arr[j]` is hoisted to the second outermost loop body. 
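+        // (Sketch of the bound check involved: `for i in 0..4` gives an exclusive upper
+        // bound of 4, and `arr` is a `[u32; 5]`, so 4 <= 5 and every access is in bounds.)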
+        let src = "
+            brillig(inline) fn main f0 {
+              b0(v0: u32, v1: u32):
+                v6 = make_array [u32 2, u32 2, u32 2, u32 2, u32 2] : [u32; 5]
+                inc_rc v6
+                jmp b1(u32 0)
+              b1(v2: u32):
+                v9 = lt v2, u32 4
+                jmpif v9 then: b3, else: b2
+              b3():
+                jmp b4(u32 0)
+              b4(v3: u32):
+                v10 = lt v3, u32 4
+                jmpif v10 then: b6, else: b5
+              b6():
+                jmp b7(u32 0)
+              b7(v4: u32):
+                v13 = lt v4, u32 4
+                jmpif v13 then: b9, else: b8
+              b9():
+                v15 = array_get v6, index v2 -> u32
+                v16 = eq v15, v0
+                constrain v15 == v0
+                v17 = array_get v6, index v3 -> u32
+                v18 = eq v17, v0
+                constrain v17 == v0
+                v19 = add v4, u32 1
+                jmp b7(v19)
+              b8():
+                v14 = add v3, u32 1
+                jmp b4(v14)
+              b5():
+                v12 = add v2, u32 1
+                jmp b1(v12)
+              b2():
+                return
+            }
+            ";
+
+        let ssa = Ssa::from_str(src).unwrap();
+
+        let expected = "
+            brillig(inline) fn main f0 {
+              b0(v0: u32, v1: u32):
+                v6 = make_array [u32 2, u32 2, u32 2, u32 2, u32 2] : [u32; 5]
+                inc_rc v6
+                jmp b1(u32 0)
+              b1(v2: u32):
+                v9 = lt v2, u32 4
+                jmpif v9 then: b3, else: b2
+              b3():
+                v10 = array_get v6, index v2 -> u32
+                v11 = eq v10, v0
+                jmp b4(u32 0)
+              b4(v3: u32):
+                v12 = lt v3, u32 4
+                jmpif v12 then: b6, else: b5
+              b6():
+                v15 = array_get v6, index v3 -> u32
+                v16 = eq v15, v0
+                jmp b7(u32 0)
+              b7(v4: u32):
+                v17 = lt v4, u32 4
+                jmpif v17 then: b9, else: b8
+              b9():
+                constrain v10 == v0
+                constrain v15 == v0
+                v19 = add v4, u32 1
+                jmp b7(v19)
+              b8():
+                v18 = add v3, u32 1
+                jmp b4(v18)
+              b5():
+                v14 = add v2, u32 1
+                jmp b1(v14)
+              b2():
+                return
+            }
+            ";
+
+        let ssa = ssa.loop_invariant_code_motion();
+        assert_normalized_ssa_equals(ssa, expected);
+    }
 }
diff --git a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs
index 77133d7d88d..53a31ae57c1 100644
--- a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs
+++ b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs
@@ -18,6 +18,7 @@
 //! - A reference with 0 aliases means we were unable to find which reference this reference
 //!   refers to. If such a reference is stored to, we must conservatively invalidate every
 //!   reference in the current block.
+//! - We also track the last load instruction to each address per block.
 //!
 //! From there, to figure out the value of each reference at the end of block, iterate each instruction:
 //!   - On `Instruction::Allocate`:
@@ -28,6 +29,13 @@
 //!     - Furthermore, if the result of the load is a reference, mark the result as an alias
 //!       of the reference it dereferences to (if known).
 //!       - If which reference it dereferences to is not known, this load result has no aliases.
+//!     - We also track the last instance of a load instruction to each address in a block.
+//!       If the last load instruction was from the same address as the current load instruction,
+//!       we replace the result of the current load with the result of the previous load.
+//!       This removal requires a couple of conditions:
+//!         - No store occurs to that address before the next load,
+//!         - The address is not used as an argument to a call.
+//!       This optimization helps us remove repeated loads for which there are no known values.
 //!   - On `Instruction::Store { address, value }`:
 //!     - If the address of the store is known:
 //!       - If the address has exactly 1 alias:
@@ -40,11 +48,13 @@
 //!     - Conservatively mark every alias in the block to `Unknown`.
 //!     - Additionally, if there were no Loads to any alias of the address between this Store and
 //!       the previous Store to the same address, the previous store can be removed.
+//!
- Remove the instance of the last load instruction to the address and its aliases //! - On `Instruction::Call { arguments }`: //! - If any argument of the call is a reference, set the value of each alias of that //! reference to `Unknown` //! - Any builtin functions that may return aliases if their input also contains a //! reference should be tracked. Examples: `slice_push_back`, `slice_insert`, `slice_remove`, etc. +//! - Remove the instance of the last load instruction for any reference arguments and their aliases //! //! On a terminator instruction: //! - If the terminator is a `Jmp`: @@ -274,6 +284,9 @@ impl<'f> PerFunctionContext<'f> { if let Some(first_predecessor) = predecessors.next() { let mut first = self.blocks.get(&first_predecessor).cloned().unwrap_or_default(); first.last_stores.clear(); + // Last loads are tracked per block. During unification we are creating a new block from the current one, + // so we must clear the last loads of the current block before we return the new block. + first.last_loads.clear(); // Note that we have to start folding with the first block as the accumulator. // If we started with an empty block, an empty block union'd with any other block @@ -410,18 +423,38 @@ impl<'f> PerFunctionContext<'f> { self.last_loads.insert(address, (instruction, block_id)); } + + // Check whether the block has a repeat load from the same address (w/ no calls or stores in between the loads). + // If we do have a repeat load, we can remove the current load and map its result to the previous load's result. + if let Some(last_load) = references.last_loads.get(&address) { + let Instruction::Load { address: previous_address } = + &self.inserter.function.dfg[*last_load] + else { + panic!("Expected a Load instruction here"); + }; + let result = self.inserter.function.dfg.instruction_results(instruction)[0]; + let previous_result = + self.inserter.function.dfg.instruction_results(*last_load)[0]; + if *previous_address == address { + self.inserter.map_value(result, previous_result); + self.instructions_to_remove.insert(instruction); + } + } + // We want to set the load for every load even if the address has a known value + // and the previous load instruction was removed. + // We are safe to still remove a repeat load in this case as we are mapping from the current load's + // result to the previous load, which if it was removed should already have a mapping to the known value. + references.set_last_load(address, instruction); } Instruction::Store { address, value } => { let address = self.inserter.function.dfg.resolve(*address); let value = self.inserter.function.dfg.resolve(*value); - // FIXME: This causes errors in the sha256 tests - // // If there was another store to this instruction without any (unremoved) loads or // function calls in-between, we can remove the previous store. - // if let Some(last_store) = references.last_stores.get(&address) { - // self.instructions_to_remove.insert(*last_store); - // } + if let Some(last_store) = references.last_stores.get(&address) { + self.instructions_to_remove.insert(*last_store); + } if self.inserter.function.dfg.value_is_reference(value) { if let Some(expression) = references.expressions.get(&value) { @@ -437,6 +470,8 @@ impl<'f> PerFunctionContext<'f> { } references.set_known_value(address, value); + // If we see a store to an address, the last load to that address needs to remain. 
+ references.keep_last_load_for(address, self.inserter.function); references.last_stores.insert(address, instruction); } Instruction::Allocate => { @@ -544,6 +579,9 @@ impl<'f> PerFunctionContext<'f> { let value = self.inserter.function.dfg.resolve(*value); references.set_unknown(value); references.mark_value_used(value, self.inserter.function); + + // If a reference is an argument to a call, the last load to that address and its aliases needs to remain. + references.keep_last_load_for(value, self.inserter.function); } } } @@ -574,6 +612,12 @@ impl<'f> PerFunctionContext<'f> { let destination_parameters = self.inserter.function.dfg[*destination].parameters(); assert_eq!(destination_parameters.len(), arguments.len()); + // If we have multiple parameters that alias that same argument value, + // then those parameters also alias each other. + // We save parameters with repeat arguments to later mark those + // parameters as aliasing one another. + let mut arg_set: HashMap> = HashMap::default(); + // Add an alias for each reference parameter for (parameter, argument) in destination_parameters.iter().zip(arguments) { if self.inserter.function.dfg.value_is_reference(*parameter) { @@ -583,10 +627,27 @@ impl<'f> PerFunctionContext<'f> { if let Some(aliases) = references.aliases.get_mut(expression) { // The argument reference is possibly aliased by this block parameter aliases.insert(*parameter); + + // Check if we have seen the same argument + let seen_parameters = arg_set.entry(argument).or_default(); + // Add the current parameter to the parameters we have seen for this argument. + // The previous parameters and the current one alias one another. + seen_parameters.insert(*parameter); } } } } + + // Set the aliases of the parameters + for (_, aliased_params) in arg_set { + for param in aliased_params.iter() { + self.set_aliases( + references, + *param, + AliasSet::known_multiple(aliased_params.clone()), + ); + } + } } TerminatorInstruction::Return { return_values, .. 
} => { // Removing all `last_stores` for each returned reference is more important here @@ -614,6 +675,8 @@ mod tests { map::Id, types::Type, }, + opt::assert_normalized_ssa_equals, + Ssa, }; #[test] @@ -824,91 +887,53 @@ mod tests { // is later stored in a successor block #[test] fn load_aliases_in_predecessor_block() { - // fn main { - // b0(): - // v0 = allocate - // store Field 0 at v0 - // v2 = allocate - // store v0 at v2 - // v3 = load v2 - // v4 = load v2 - // jmp b1() - // b1(): - // store Field 1 at v3 - // store Field 2 at v4 - // v7 = load v3 - // v8 = eq v7, Field 2 - // return - // } - let main_id = Id::test_new(0); - let mut builder = FunctionBuilder::new("main".into(), main_id); - - let v0 = builder.insert_allocate(Type::field()); - - let zero = builder.field_constant(0u128); - builder.insert_store(v0, zero); - - let v2 = builder.insert_allocate(Type::Reference(Arc::new(Type::field()))); - builder.insert_store(v2, v0); - - let v3 = builder.insert_load(v2, Type::field()); - let v4 = builder.insert_load(v2, Type::field()); - let b1 = builder.insert_block(); - builder.terminate_with_jmp(b1, vec![]); - - builder.switch_to_block(b1); - - let one = builder.field_constant(1u128); - builder.insert_store(v3, one); - - let two = builder.field_constant(2u128); - builder.insert_store(v4, two); - - let v8 = builder.insert_load(v3, Type::field()); - let _ = builder.insert_binary(v8, BinaryOp::Eq, two); - - builder.terminate_with_return(vec![]); - - let ssa = builder.finish(); - assert_eq!(ssa.main().reachable_blocks().len(), 2); + let src = " + acir(inline) fn main f0 { + b0(): + v0 = allocate -> &mut Field + store Field 0 at v0 + v2 = allocate -> &mut &mut Field + store v0 at v2 + v3 = load v2 -> &mut Field + v4 = load v2 -> &mut Field + jmp b1() + b1(): + store Field 1 at v3 + store Field 2 at v4 + v7 = load v3 -> Field + v8 = eq v7, Field 2 + return + } + "; - // Expected result: - // acir fn main f0 { - // b0(): - // v9 = allocate - // store Field 0 at v9 - // v10 = allocate - // jmp b1() - // b1(): - // return - // } - let ssa = ssa.mem2reg(); - println!("{}", ssa); + let mut ssa = Ssa::from_str(src).unwrap(); + let main = ssa.main_mut(); - let main = ssa.main(); - assert_eq!(main.reachable_blocks().len(), 2); + let instructions = main.dfg[main.entry_block()].instructions(); + assert_eq!(instructions.len(), 6); // The final return is not counted // All loads should be removed - assert_eq!(count_loads(main.entry_block(), &main.dfg), 0); - assert_eq!(count_loads(b1, &main.dfg), 0); - // The first store is not removed as it is used as a nested reference in another store. - // We would need to track whether the store where `v9` is the store value gets removed to know whether + // We would need to track whether the store where `v0` is the store value gets removed to know whether // to remove it. - assert_eq!(count_stores(main.entry_block(), &main.dfg), 1); - // The first store in b1 is removed since there is another store to the same reference // in the same block, and the store is not needed before the later store. // The rest of the stores are also removed as no loads are done within any blocks // to the stored values. - // - // NOTE: This store is not removed due to the FIXME when handling Instruction::Store. 
-        assert_eq!(count_stores(b1, &main.dfg), 1);
-
-        let b1_instructions = main.dfg[b1].instructions();
+        let expected = "
+            acir(inline) fn main f0 {
+              b0():
+                v0 = allocate -> &mut Field
+                store Field 0 at v0
+                v2 = allocate -> &mut &mut Field
+                jmp b1()
+              b1():
+                return
+            }
+            ";
 
-        // We expect the last eq to be optimized out, only the store from above remains
-        assert_eq!(b1_instructions.len(), 1);
+        let ssa = ssa.mem2reg();
+        assert_normalized_ssa_equals(ssa, expected);
     }
 
     #[test]
@@ -938,7 +963,7 @@ mod tests {
     //     v10 = eq v9, Field 2
     //     constrain v9 == Field 2
     //     v11 = load v2
-    //     v12 = load v10
+    //     v12 = load v11
     //     v13 = eq v12, Field 2
     //     constrain v11 == Field 2
     //     return
@@ -997,7 +1022,7 @@ mod tests {
         let main = ssa.main();
         assert_eq!(main.reachable_blocks().len(), 4);
 
-        // The store from the original SSA should remain
+        // The stores from the original SSA should remain
         assert_eq!(count_stores(main.entry_block(), &main.dfg), 2);
         assert_eq!(count_stores(b2, &main.dfg), 1);
 
@@ -1044,4 +1069,160 @@ mod tests {
         let main = ssa.main();
         assert_eq!(count_loads(main.entry_block(), &main.dfg), 1);
     }
+
+    #[test]
+    fn remove_repeat_loads() {
+        // This test starts with two loads from the same address whose value is unknown.
+        // Specifically, look for the two `load v2` instructions in `b3`.
+        // We should be able to remove the second, repeated load.
+        let src = "
+            acir(inline) fn main f0 {
+              b0():
+                v0 = allocate -> &mut Field
+                store Field 0 at v0
+                v2 = allocate -> &mut &mut Field
+                store v0 at v2
+                jmp b1(Field 0)
+              b1(v3: Field):
+                v4 = eq v3, Field 0
+                jmpif v4 then: b2, else: b3
+              b2():
+                v5 = load v2 -> &mut Field
+                store Field 2 at v5
+                v8 = add v3, Field 1
+                jmp b1(v8)
+              b3():
+                v9 = load v0 -> Field
+                v10 = eq v9, Field 2
+                constrain v9 == Field 2
+                v11 = load v2 -> &mut Field
+                v12 = load v2 -> &mut Field
+                v13 = load v12 -> Field
+                v14 = eq v13, Field 2
+                constrain v13 == Field 2
+                return
+            }
+            ";
+
+        let ssa = Ssa::from_str(src).unwrap();
+
+        // The repeated load from v3 should be removed.
+        // b3 should now have three loads rather than the previous four.
+        //
+        // All stores are expected to remain.
+        let expected = "
+            acir(inline) fn main f0 {
+              b0():
+                v1 = allocate -> &mut Field
+                store Field 0 at v1
+                v3 = allocate -> &mut &mut Field
+                store v1 at v3
+                jmp b1(Field 0)
+              b1(v0: Field):
+                v4 = eq v0, Field 0
+                jmpif v4 then: b3, else: b2
+              b3():
+                v11 = load v3 -> &mut Field
+                store Field 2 at v11
+                v13 = add v0, Field 1
+                jmp b1(v13)
+              b2():
+                v5 = load v1 -> Field
+                v7 = eq v5, Field 2
+                constrain v5 == Field 2
+                v8 = load v3 -> &mut Field
+                v9 = load v8 -> Field
+                v10 = eq v9, Field 2
+                constrain v9 == Field 2
+                return
+            }
+            ";
+
+        let ssa = ssa.mem2reg();
+        assert_normalized_ssa_equals(ssa, expected);
+    }
+
+    #[test]
+    fn keep_repeat_loads_passed_to_a_call() {
+        // This test is the same as `remove_repeat_loads` above, except that a call
+        // to `f1` sits between the repeated loads.
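+        // (Passing `v3` to `f1` makes the later loads from it non-removable: see
+        // `keep_last_load_for`, which drops the tracked last load for call arguments.)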
+        let src = "
+            acir(inline) fn main f0 {
+              b0():
+                v1 = allocate -> &mut Field
+                store Field 0 at v1
+                v3 = allocate -> &mut &mut Field
+                store v1 at v3
+                jmp b1(Field 0)
+              b1(v0: Field):
+                v4 = eq v0, Field 0
+                jmpif v4 then: b3, else: b2
+              b3():
+                v13 = load v3 -> &mut Field
+                store Field 2 at v13
+                v15 = add v0, Field 1
+                jmp b1(v15)
+              b2():
+                v5 = load v1 -> Field
+                v7 = eq v5, Field 2
+                constrain v5 == Field 2
+                v8 = load v3 -> &mut Field
+                call f1(v3)
+                v10 = load v3 -> &mut Field
+                v11 = load v10 -> Field
+                v12 = eq v11, Field 2
+                constrain v11 == Field 2
+                return
+            }
+            acir(inline) fn foo f1 {
+              b0(v0: &mut Field):
+                return
+            }
+            ";
+
+        let ssa = Ssa::from_str(src).unwrap();
+
+        let ssa = ssa.mem2reg();
+        // We expect the program to be unchanged
+        assert_normalized_ssa_equals(ssa, src);
+    }
+
+    #[test]
+    fn keep_repeat_loads_with_alias_store() {
+        // v1, v2, and v3 alias one another. We want to make sure that a repeated load from v1 with stores
+        // to its aliases in between the repeated loads does not remove those loads.
+        let src = "
+            acir(inline) fn main f0 {
+              b0(v0: u1):
+                jmpif v0 then: b2, else: b1
+              b2():
+                v6 = allocate -> &mut Field
+                store Field 0 at v6
+                jmp b3(v6, v6, v6)
+              b3(v1: &mut Field, v2: &mut Field, v3: &mut Field):
+                v8 = load v1 -> Field
+                store Field 2 at v2
+                v10 = load v1 -> Field
+                store Field 1 at v3
+                v11 = load v1 -> Field
+                store Field 3 at v3
+                v13 = load v1 -> Field
+                constrain v8 == Field 0
+                constrain v10 == Field 2
+                constrain v11 == Field 1
+                constrain v13 == Field 3
+                return
+              b1():
+                v4 = allocate -> &mut Field
+                store Field 1 at v4
+                jmp b3(v4, v4, v4)
+            }
+            ";
+
+        let ssa = Ssa::from_str(src).unwrap();
+
+        let ssa = ssa.mem2reg();
+        // We expect the program to be unchanged
+        assert_normalized_ssa_equals(ssa, src);
+    }
 }
diff --git a/compiler/noirc_evaluator/src/ssa/opt/mem2reg/alias_set.rs b/compiler/noirc_evaluator/src/ssa/opt/mem2reg/alias_set.rs
index 4d768caa36b..e32eaa70186 100644
--- a/compiler/noirc_evaluator/src/ssa/opt/mem2reg/alias_set.rs
+++ b/compiler/noirc_evaluator/src/ssa/opt/mem2reg/alias_set.rs
@@ -24,6 +24,10 @@ impl AliasSet {
         Self { aliases: Some(aliases) }
     }
 
+    pub(super) fn known_multiple(values: BTreeSet) -> AliasSet {
+        Self { aliases: Some(values) }
+    }
+
     /// In rare cases, such as when creating an empty array of references, the set of aliases for a
     /// particular value will be known to be zero, which is distinct from being unknown and
     /// possibly referring to any alias.
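A minimal sketch of how `known_multiple` is used by the terminator handling above (`param_a` and `param_b` are illustrative stand-ins for jump parameters that received the same reference argument):

    // Parameters that share an argument alias each other, so each one is
    // assigned the full set of parameters seen for that argument.
    let mut aliased_params = BTreeSet::new();
    aliased_params.insert(param_a);
    aliased_params.insert(param_b);
    let aliases = AliasSet::known_multiple(aliased_params.clone());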
diff --git a/compiler/noirc_evaluator/src/ssa/opt/mem2reg/block.rs b/compiler/noirc_evaluator/src/ssa/opt/mem2reg/block.rs index 532785d2928..f4265b2466d 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mem2reg/block.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mem2reg/block.rs @@ -34,6 +34,9 @@ pub(super) struct Block { /// The last instance of a `Store` instruction to each address in this block pub(super) last_stores: im::OrdMap, + + // The last instance of a `Load` instruction to each address in this block + pub(super) last_loads: im::OrdMap, } /// An `Expression` here is used to represent a canonical key @@ -237,4 +240,14 @@ impl Block { Cow::Owned(AliasSet::unknown()) } + + pub(super) fn set_last_load(&mut self, address: ValueId, instruction: InstructionId) { + self.last_loads.insert(address, instruction); + } + + pub(super) fn keep_last_load_for(&mut self, address: ValueId, function: &Function) { + let address = function.dfg.resolve(address); + self.last_loads.remove(&address); + self.for_each_alias_of(address, |block, alias| block.last_loads.remove(&alias)); + } } diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs index cdbb1043232..ccf5bd9d9f8 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs @@ -191,7 +191,7 @@ impl Context<'_> { let typ = self.function.dfg.type_of_value(rhs); if let Type::Numeric(NumericType::Unsigned { bit_size }) = typ { let to_bits = self.function.dfg.import_intrinsic(Intrinsic::ToBits(Endian::Little)); - let result_types = vec![Type::Array(Arc::new(vec![Type::bool()]), bit_size as usize)]; + let result_types = vec![Type::Array(Arc::new(vec![Type::bool()]), bit_size)]; let rhs_bits = self.insert_call(to_bits, vec![rhs], result_types); let rhs_bits = rhs_bits[0]; diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs index 8e25c3f0a35..182e6e54d0f 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs @@ -48,7 +48,7 @@ impl Function { #[derive(Default)] struct Context { - slice_sizes: HashMap, + slice_sizes: HashMap, // Maps array_set result -> element that was overwritten by that instruction. 
// Used to undo array_sets while merging values
@@ -142,13 +142,13 @@ impl Context {
         }
     }
 
-    fn get_or_find_capacity(&mut self, dfg: &DataFlowGraph, value: ValueId) -> usize {
+    fn get_or_find_capacity(&mut self, dfg: &DataFlowGraph, value: ValueId) -> u32 {
         match self.slice_sizes.entry(value) {
             Entry::Occupied(entry) => return *entry.get(),
             Entry::Vacant(entry) => {
                 if let Some((array, typ)) = dfg.get_array_constant(value) {
                     let length = array.len() / typ.element_types().len();
-                    return *entry.insert(length);
+                    return *entry.insert(length as u32);
                 }
 
                 if let Type::Array(_, length) = dfg.type_of_value(value) {
@@ -164,7 +164,7 @@ impl Context {
 
 enum SizeChange {
     None,
-    SetTo(ValueId, usize),
+    SetTo(ValueId, u32),
 
     // These two variants store the old and new slice ids
     // not their lengths which should be old_len = new_len +/- 1
diff --git a/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs
index 46941775c5e..c282e2df451 100644
--- a/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs
+++ b/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs
@@ -18,7 +18,8 @@ use crate::ssa::{
         basic_block::BasicBlockId,
         cfg::ControlFlowGraph,
         function::{Function, RuntimeType},
-        instruction::TerminatorInstruction,
+        instruction::{Instruction, TerminatorInstruction},
+        value::Value,
     },
     ssa_gen::Ssa,
 };
@@ -31,6 +32,7 @@ impl Ssa {
     ///    4. Removing any blocks which have no instructions other than a single terminating jmp.
     ///    5. Replacing any jmpifs with constant conditions with jmps. If this causes the block to have
     ///       only 1 successor then (2) also will be applied.
+    ///    6. Replacing any jmpif on a negated condition with a jmpif on the un-negated condition and reversed branches.
     ///
     /// Currently, 1 is unimplemented.
     #[tracing::instrument(level = "trace", skip(self))]
@@ -55,6 +57,8 @@ impl Function {
             stack.extend(self.dfg[block].successors().filter(|block| !visited.contains(block)));
         }
 
+        check_for_negated_jmpif_condition(self, block, &mut cfg);
+
         // This call is before try_inline_into_predecessor so that if it succeeds in changing a
         // jmpif into a jmp, the block may then be inlined entirely into its predecessor in try_inline_into_predecessor.
         check_for_constant_jmpif(self, block, &mut cfg);
@@ -184,6 +188,55 @@ fn check_for_double_jmp(function: &mut Function, block: BasicBlockId, cfg: &mut
     cfg.recompute_block(function, block);
 }
 
+/// Optimize a jmpif on a negated condition by swapping the branches.
+fn check_for_negated_jmpif_condition(
+    function: &mut Function,
+    block: BasicBlockId,
+    cfg: &mut ControlFlowGraph,
+) {
+    if matches!(function.runtime(), RuntimeType::Acir(_)) {
+        // Swapping the `then` and `else` branches of a `JmpIf` within an ACIR function
+        // can result in the situation where the branches merge together again in the `then` block, e.g.
+        //
+        // acir(inline) fn main f0 {
+        //   b0(v0: u1):
+        //     jmpif v0 then: b2, else: b1
+        //   b2():
+        //     return
+        //   b1():
+        //     jmp b2()
+        // }
+        //
+        // This breaks the `flatten_cfg` pass as it assumes that merges only happen in
+        // the `else` block or a 3rd block.
+        //
+        // See: https://github.com/noir-lang/noir/pull/5891#issuecomment-2500219428
+        return;
+    }
+
+    if let Some(TerminatorInstruction::JmpIf {
+        condition,
+        then_destination,
+        else_destination,
+        call_stack,
+    }) = function.dfg[block].terminator()
+    {
+        if let Value::Instruction { instruction, ..
} = function.dfg[*condition] {
+            if let Instruction::Not(negated_condition) = function.dfg[instruction] {
+                let call_stack = call_stack.clone();
+                let jmpif = TerminatorInstruction::JmpIf {
+                    condition: negated_condition,
+                    then_destination: *else_destination,
+                    else_destination: *then_destination,
+                    call_stack,
+                };
+                function.dfg[block].set_terminator(jmpif);
+                cfg.recompute_block(function, block);
+            }
+        }
+    }
+}
+
 /// If the given block has block parameters, replace them with the jump arguments from the predecessor.
 ///
 /// Currently, if this function is needed, `try_inline_into_predecessor` will also always apply,
@@ -246,6 +299,8 @@ mod test {
         map::Id,
         types::Type,
     },
+    opt::assert_normalized_ssa_equals,
+    Ssa,
 };
 use acvm::acir::AcirField;
 
@@ -359,4 +414,59 @@ mod test {
             other => panic!("Unexpected terminator {other:?}"),
         }
     }
+
+    #[test]
+    fn swap_negated_jmpif_branches_in_brillig() {
+        let src = "
+        brillig(inline) fn main f0 {
+          b0(v0: u1):
+            v1 = allocate -> &mut Field
+            store Field 0 at v1
+            v3 = not v0
+            jmpif v3 then: b1, else: b2
+          b1():
+            store Field 2 at v1
+            jmp b2()
+          b2():
+            v5 = load v1 -> Field
+            v6 = eq v5, Field 2
+            constrain v5 == Field 2
+            return
+        }";
+        let ssa = Ssa::from_str(src).unwrap();
+
+        let expected = "
+        brillig(inline) fn main f0 {
+          b0(v0: u1):
+            v1 = allocate -> &mut Field
+            store Field 0 at v1
+            v3 = not v0
+            jmpif v0 then: b2, else: b1
+          b2():
+            v5 = load v1 -> Field
+            v6 = eq v5, Field 2
+            constrain v5 == Field 2
+            return
+          b1():
+            store Field 2 at v1
+            jmp b2()
+        }";
+        assert_normalized_ssa_equals(ssa.simplify_cfg(), expected);
+    }
+
+    #[test]
+    fn does_not_swap_negated_jmpif_branches_in_acir() {
+        let src = "
+        acir(inline) fn main f0 {
+          b0(v0: u1):
+            v1 = not v0
+            jmpif v1 then: b1, else: b2
+          b1():
+            jmp b2()
+          b2():
+            return
+        }";
+        let ssa = Ssa::from_str(src).unwrap();
+        assert_normalized_ssa_equals(ssa.simplify_cfg(), src);
+    }
 }
diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
index 777c16dacd1..1a13acc5435 100644
--- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
+++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
@@ -19,8 +19,10 @@
 //! When unrolling ACIR code, we remove reference count instructions because they are
 //! only used by Brillig bytecode.
 use acvm::{acir::AcirField, FieldElement};
+use im::HashSet;
 
 use crate::{
+    brillig::brillig_gen::convert_ssa_function,
     errors::RuntimeError,
     ssa::{
         ir::{
@@ -37,38 +39,60 @@ use crate::{
         ssa_gen::Ssa,
     },
 };
-use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};
+use fxhash::FxHashMap as HashMap;
 
 impl Ssa {
     /// Loop unrolling can return errors, since ACIR functions need to be fully unrolled.
     /// This meta-pass will keep trying to unroll loops and simplifying the SSA until no more errors are found.
-    #[tracing::instrument(level = "trace", skip(ssa))]
-    pub(crate) fn unroll_loops_iteratively(mut ssa: Ssa) -> Result {
-        for (_, function) in ssa.functions.iter_mut() {
+    ///
+    /// The `max_bytecode_increase_percent`, when given, is used to limit the growth of the Brillig bytecode size
+    /// after unrolling small loops to some percentage of the original loop. For example, a value of 150 would
+    /// mean the new loop can be 150% (i.e. 2.5 times) larger than the original loop. It will still contain
+    /// fewer SSA instructions, but that can still result in more Brillig opcodes.
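+    ///
+    /// For example, the tests below exercise this with `unroll_loops_iteratively(Some(50))`,
+    /// which allows an unrolled Brillig function to grow to at most 150% of its original
+    /// bytecode size, and with `Some(-90)`, which demands it shrink to at most 10%.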
+ #[tracing::instrument(level = "trace", skip(self))] + pub(crate) fn unroll_loops_iteratively( + mut self: Ssa, + max_bytecode_increase_percent: Option, + ) -> Result { + for (_, function) in self.functions.iter_mut() { + // Take a snapshot of the function to compare byte size increase, + // but only if the setting indicates we have to, otherwise skip it. + let orig_func_and_max_incr_pct = max_bytecode_increase_percent + .filter(|_| function.runtime().is_brillig()) + .map(|max_incr_pct| (function.clone(), max_incr_pct)); + // Try to unroll loops first: - let mut unroll_errors = function.try_unroll_loops(); + let (mut has_unrolled, mut unroll_errors) = function.try_unroll_loops(); // Keep unrolling until no more errors are found while !unroll_errors.is_empty() { let prev_unroll_err_count = unroll_errors.len(); // Simplify the SSA before retrying - - // Do a mem2reg after the last unroll to aid simplify_cfg - function.mem2reg(); - function.simplify_function(); - // Do another mem2reg after simplify_cfg to aid the next unroll - function.mem2reg(); + simplify_between_unrolls(function); // Unroll again - unroll_errors = function.try_unroll_loops(); + let (new_unrolled, new_errors) = function.try_unroll_loops(); + unroll_errors = new_errors; + has_unrolled |= new_unrolled; + // If we didn't manage to unroll any more loops, exit if unroll_errors.len() >= prev_unroll_err_count { return Err(unroll_errors.swap_remove(0)); } } + + if has_unrolled { + if let Some((orig_function, max_incr_pct)) = orig_func_and_max_incr_pct { + let new_size = brillig_bytecode_size(function); + let orig_size = brillig_bytecode_size(&orig_function); + if !is_new_size_ok(orig_size, new_size, max_incr_pct) { + *function = orig_function; + } + } + } } - Ok(ssa) + Ok(self) } } @@ -77,7 +101,7 @@ impl Function { // This can also be true for ACIR, but we have no alternative to unrolling in ACIR. // Brillig also generally prefers smaller code rather than faster code, // so we only attempt to unroll small loops, which we decide on a case-by-case basis. - fn try_unroll_loops(&mut self) -> Vec { + fn try_unroll_loops(&mut self) -> (bool, Vec) { Loops::find_all(self).unroll_each(self) } } @@ -85,7 +109,7 @@ impl Function { pub(super) struct Loop { /// The header block of a loop is the block which dominates all the /// other blocks in the loop. - header: BasicBlockId, + pub(super) header: BasicBlockId, /// The start of the back_edge n -> d is the block n at the end of /// the loop that jumps back to the header block d which restarts the loop. @@ -170,8 +194,10 @@ impl Loops { /// Unroll all loops within a given function. /// Any loops which fail to be unrolled (due to using non-constant indices) will be unmodified. 
- fn unroll_each(mut self, function: &mut Function) -> Vec { + /// Returns whether any blocks have been modified + fn unroll_each(mut self, function: &mut Function) -> (bool, Vec) { let mut unroll_errors = vec![]; + let mut has_unrolled = false; while let Some(next_loop) = self.yet_to_unroll.pop() { if function.runtime().is_brillig() && !next_loop.is_small_loop(function, &self.cfg) { continue; @@ -181,13 +207,17 @@ impl Loops { if next_loop.blocks.iter().any(|block| self.modified_blocks.contains(block)) { let mut new_loops = Self::find_all(function); new_loops.failed_to_unroll = self.failed_to_unroll; - return unroll_errors.into_iter().chain(new_loops.unroll_each(function)).collect(); + let (new_unrolled, new_errors) = new_loops.unroll_each(function); + return (has_unrolled || new_unrolled, [unroll_errors, new_errors].concat()); } // Don't try to unroll the loop again if it is known to fail if !self.failed_to_unroll.contains(&next_loop.header) { match next_loop.unroll(function, &self.cfg) { - Ok(_) => self.modified_blocks.extend(next_loop.blocks), + Ok(_) => { + has_unrolled = true; + self.modified_blocks.extend(next_loop.blocks); + } Err(call_stack) => { self.failed_to_unroll.insert(next_loop.header); unroll_errors.push(RuntimeError::UnknownLoopBound { call_stack }); @@ -195,7 +225,7 @@ impl Loops { } } } - unroll_errors + (has_unrolled, unroll_errors) } } @@ -269,7 +299,7 @@ impl Loop { /// v5 = lt v1, u32 4 // Upper bound /// jmpif v5 then: b3, else: b2 /// ``` - fn get_const_upper_bound(&self, function: &Function) -> Option { + pub(super) fn get_const_upper_bound(&self, function: &Function) -> Option { let block = &function.dfg[self.header]; let instructions = block.instructions(); assert_eq!( @@ -947,21 +977,59 @@ impl<'f> LoopIteration<'f> { } } +/// Unrolling leaves some duplicate instructions which can potentially be removed. +fn simplify_between_unrolls(function: &mut Function) { + // Do a mem2reg after the last unroll to aid simplify_cfg + function.mem2reg(); + function.simplify_function(); + // Do another mem2reg after simplify_cfg to aid the next unroll + function.mem2reg(); +} + +/// Convert the function to Brillig bytecode and return the resulting size. +fn brillig_bytecode_size(function: &Function) -> usize { + // We need to do some SSA passes in order for the conversion to be able to go ahead, + // otherwise we can hit `unreachable!()` instructions in `convert_ssa_instruction`. + // Creating a clone so as not to modify the originals. + let mut temp = function.clone(); + + // Might as well give it the best chance. + simplify_between_unrolls(&mut temp); + + // This is to try to prevent hitting ICE. + temp.dead_instruction_elimination(false); + + convert_ssa_function(&temp, false).byte_code.len() +} + +/// Decide if the new bytecode size is acceptable, compared to the original. +/// +/// The maximum increase can be expressed as a negative value if we demand a decrease. +/// (Values -100 and under mean the new size should be 0). 
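+/// As a worked example: with `orig_size = 1000` and `max_incr_pct = 50`, any `new_size`
+/// up to 1500 passes the check (`new_size * 100 <= orig_size * 150`), while with
+/// `max_incr_pct = -50` the new size must be at most 500.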
+fn is_new_size_ok(orig_size: usize, new_size: usize, max_incr_pct: i32) -> bool { + let max_size_pct = 100i32.saturating_add(max_incr_pct).max(0) as usize; + let max_size = orig_size.saturating_mul(max_size_pct); + new_size.saturating_mul(100) <= max_size +} + #[cfg(test)] mod tests { use acvm::FieldElement; + use test_case::test_case; use crate::errors::RuntimeError; use crate::ssa::{ir::value::ValueId, opt::assert_normalized_ssa_equals, Ssa}; - use super::{BoilerplateStats, Loops}; + use super::{is_new_size_ok, BoilerplateStats, Loops}; - /// Tries to unroll all loops in each SSA function. + /// Tries to unroll all loops in each SSA function once, calling the `Function` directly, + /// bypassing the iterative loop done by the SSA which does further optimisations. + /// /// If any loop cannot be unrolled, it is left as-is or in a partially unrolled state. fn try_unroll_loops(mut ssa: Ssa) -> (Ssa, Vec) { let mut errors = vec![]; for function in ssa.functions.values_mut() { - errors.extend(function.try_unroll_loops()); + errors.extend(function.try_unroll_loops().1); } (ssa, errors) } @@ -1221,9 +1289,26 @@ mod tests { let (ssa, errors) = try_unroll_loops(ssa); assert_eq!(errors.len(), 0, "Unroll should have no errors"); + // Check that it's still the original assert_normalized_ssa_equals(ssa, parse_ssa().to_string().as_str()); } + #[test] + fn test_brillig_unroll_iteratively_respects_max_increase() { + let ssa = brillig_unroll_test_case(); + let ssa = ssa.unroll_loops_iteratively(Some(-90)).unwrap(); + // Check that it's still the original + assert_normalized_ssa_equals(ssa, brillig_unroll_test_case().to_string().as_str()); + } + + #[test] + fn test_brillig_unroll_iteratively_with_large_max_increase() { + let ssa = brillig_unroll_test_case(); + let ssa = ssa.unroll_loops_iteratively(Some(50)).unwrap(); + // Check that it did the unroll + assert_eq!(ssa.main().reachable_blocks().len(), 2, "The loop should be unrolled"); + } + /// Test that `break` and `continue` stop unrolling without any panic. #[test] fn test_brillig_unroll_break_and_continue() { @@ -1377,4 +1462,14 @@ mod tests { let loop0 = loops.yet_to_unroll.pop().expect("there should be a loop"); loop0.boilerplate_stats(function, &loops.cfg).expect("there should be stats") } + + #[test_case(1000, 700, 50, true; "size decreased")] + #[test_case(1000, 1500, 50, true; "size increased just by the max")] + #[test_case(1000, 1501, 50, false; "size increased over the max")] + #[test_case(1000, 700, -50, false; "size decreased but not enough")] + #[test_case(1000, 250, -50, true; "size decreased over expectations")] + #[test_case(1000, 250, -1250, false; "demanding more than minus 100 is handled")] + fn test_is_new_size_ok(old: usize, new: usize, max: i32, ok: bool) { + assert_eq!(is_new_size_ok(old, new, max), ok); + } } diff --git a/compiler/noirc_evaluator/src/ssa/parser/mod.rs b/compiler/noirc_evaluator/src/ssa/parser/mod.rs index 3d8bd37dead..506d2df3dea 100644 --- a/compiler/noirc_evaluator/src/ssa/parser/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/parser/mod.rs @@ -656,7 +656,7 @@ impl<'a> Parser<'a> { if self.eat(Token::Semicolon)? 
{ let length = self.eat_int_or_error()?; self.eat_or_error(Token::RightBracket)?; - return Ok(Type::Array(Arc::new(element_types), length.to_u128() as usize)); + return Ok(Type::Array(Arc::new(element_types), length.to_u128() as u32)); } else { self.eat_or_error(Token::RightBracket)?; return Ok(Type::Slice(Arc::new(element_types))); diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs index ddc3365b551..8a09dba64c4 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs @@ -237,12 +237,12 @@ impl<'a> FunctionContext<'a> { ast::Type::Field => Type::field(), ast::Type::Array(len, element) => { let element_types = Self::convert_type(element).flatten(); - Type::Array(Arc::new(element_types), *len as usize) + Type::Array(Arc::new(element_types), *len) } ast::Type::Integer(Signedness::Signed, bits) => Type::signed((*bits).into()), ast::Type::Integer(Signedness::Unsigned, bits) => Type::unsigned((*bits).into()), ast::Type::Bool => Type::unsigned(1), - ast::Type::String(len) => Type::str(*len as usize), + ast::Type::String(len) => Type::str(*len), ast::Type::FmtString(_, _) => { panic!("convert_non_tuple_type called on a fmt string: {typ}") } @@ -733,9 +733,6 @@ impl<'a> FunctionContext<'a> { let element_types = Self::convert_type(element_type); values.map_both(element_types, |value, element_type| { let reference = value.eval_reference(); - // Reference counting in brillig relies on us incrementing reference - // counts when arrays/slices are constructed or indexed. - // Thus, if we dereference an lvalue which happens to be array/slice we should increment its reference counter. self.builder.insert_load(reference, element_type).into() }) } diff --git a/compiler/noirc_frontend/Cargo.toml b/compiler/noirc_frontend/Cargo.toml index 581d7f1b61d..5d1520af54f 100644 --- a/compiler/noirc_frontend/Cargo.toml +++ b/compiler/noirc_frontend/Cargo.toml @@ -30,8 +30,8 @@ cfg-if.workspace = true tracing.workspace = true petgraph = "0.6" rangemap = "1.4.0" -strum = "0.24" -strum_macros = "0.24" +strum.workspace = true +strum_macros.workspace = true [dev-dependencies] diff --git a/compiler/noirc_frontend/src/debug/mod.rs b/compiler/noirc_frontend/src/debug/mod.rs index fed3149118b..f05fc721581 100644 --- a/compiler/noirc_frontend/src/debug/mod.rs +++ b/compiler/noirc_frontend/src/debug/mod.rs @@ -67,12 +67,16 @@ impl DebugInstrumenter { self.insert_state_set_oracle(module, 8); } - fn insert_var(&mut self, var_name: &str) -> SourceVarId { + fn insert_var(&mut self, var_name: &str) -> Option { + if var_name == "_" { + return None; + } + let var_id = SourceVarId(self.next_var_id); self.next_var_id += 1; self.variables.insert(var_id, var_name.to_string()); self.scope.last_mut().unwrap().insert(var_name.to_string(), var_id); - var_id + Some(var_id) } fn lookup_var(&self, var_name: &str) -> Option { @@ -107,9 +111,9 @@ impl DebugInstrumenter { .flat_map(|param| { pattern_vars(¶m.pattern) .iter() - .map(|(id, _is_mut)| { - let var_id = self.insert_var(&id.0.contents); - build_assign_var_stmt(var_id, id_expr(id)) + .filter_map(|(id, _is_mut)| { + let var_id = self.insert_var(&id.0.contents)?; + Some(build_assign_var_stmt(var_id, id_expr(id))) }) .collect::>() }) @@ -225,13 +229,28 @@ impl DebugInstrumenter { } }) .collect(); - let vars_exprs: Vec = vars.iter().map(|(id, _)| id_expr(id)).collect(); + let vars_exprs: Vec = vars + .iter() + .map(|(id, _)| { + // We don't want to generate an 
expression to read from "_".
+                    // And since this expression is going to be assigned to "_", it doesn't matter
+                    // what it is, so we can use `()` for it.
+                    if id.0.contents == "_" {
+                        ast::Expression {
+                            kind: ast::ExpressionKind::Literal(ast::Literal::Unit),
+                            span: id.span(),
+                        }
+                    } else {
+                        id_expr(id)
+                    }
+                })
+                .collect();
 
         let mut block_stmts =
             vec![ast::Statement { kind: ast::StatementKind::Let(let_stmt.clone()), span: *span }];
-        block_stmts.extend(vars.iter().map(|(id, _)| {
-            let var_id = self.insert_var(&id.0.contents);
-            build_assign_var_stmt(var_id, id_expr(id))
+        block_stmts.extend(vars.iter().filter_map(|(id, _)| {
+            let var_id = self.insert_var(&id.0.contents)?;
+            Some(build_assign_var_stmt(var_id, id_expr(id)))
         }));
         block_stmts.push(ast::Statement {
             kind: ast::StatementKind::Expression(ast::Expression {
@@ -422,21 +441,31 @@ impl DebugInstrumenter {
         let var_name = &for_stmt.identifier.0.contents;
         let var_id = self.insert_var(var_name);
 
-        let set_stmt = build_assign_var_stmt(var_id, id_expr(&for_stmt.identifier));
-        let drop_stmt = build_drop_var_stmt(var_id, Span::empty(for_stmt.span.end()));
+        let set_and_drop_stmt = var_id.map(|var_id| {
+            (
+                build_assign_var_stmt(var_id, id_expr(&for_stmt.identifier)),
+                build_drop_var_stmt(var_id, Span::empty(for_stmt.span.end())),
+            )
+        });
 
         self.walk_expr(&mut for_stmt.block);
+
+        let mut statements = Vec::new();
+        let block_statement = ast::Statement {
+            kind: ast::StatementKind::Semi(for_stmt.block.clone()),
+            span: for_stmt.block.span,
+        };
+
+        if let Some((set_stmt, drop_stmt)) = set_and_drop_stmt {
+            statements.push(set_stmt);
+            statements.push(block_statement);
+            statements.push(drop_stmt);
+        } else {
+            statements.push(block_statement);
+        }
+
         for_stmt.block = ast::Expression {
-            kind: ast::ExpressionKind::Block(ast::BlockExpression {
-                statements: vec![
-                    set_stmt,
-                    ast::Statement {
-                        kind: ast::StatementKind::Semi(for_stmt.block.clone()),
-                        span: for_stmt.block.span,
-                    },
-                    drop_stmt,
-                ],
-            }),
+            kind: ast::ExpressionKind::Block(ast::BlockExpression { statements }),
             span: for_stmt.span,
         };
     }
diff --git a/compiler/noirc_frontend/src/elaborator/comptime.rs b/compiler/noirc_frontend/src/elaborator/comptime.rs
index a27e2bf0163..962356d6dd9 100644
--- a/compiler/noirc_frontend/src/elaborator/comptime.rs
+++ b/compiler/noirc_frontend/src/elaborator/comptime.rs
@@ -329,8 +329,6 @@ impl<'context> Elaborator<'context> {
                 push_arg(Value::TraitDefinition(trait_id));
             } else {
                 let (expr_id, expr_type) = interpreter.elaborator.elaborate_expression(arg);
-                push_arg(interpreter.evaluate(expr_id)?);
-
                 if let Err(UnificationError) = expr_type.unify(param_type) {
                     return Err(InterpreterError::TypeMismatch {
                         expected: param_type.clone(),
@@ -338,6 +336,7 @@
                         location: arg_location,
                     });
                 }
+                push_arg(interpreter.evaluate(expr_id)?);
             };
         }
diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs
index 8b8130a7653..d0f8043cd95 100644
--- a/compiler/noirc_frontend/src/elaborator/mod.rs
+++ b/compiler/noirc_frontend/src/elaborator/mod.rs
@@ -440,6 +440,9 @@ impl<'context> Elaborator<'context> {
         // so we need to reintroduce the same IDs into scope here.
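+        // Note: parameters named "_" are skipped below; since the elaborator no
+        // longer defines "_" in scope (see the matching change in patterns.rs),
+        // there is nothing to reintroduce for them.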
for parameter in &func_meta.parameter_idents { let name = self.interner.definition_name(parameter.id).to_owned(); + if name == "_" { + continue; + } let warn_if_unused = !(func_meta.trait_impl.is_some() && name == "self"); self.add_existing_variable_to_scope(name, parameter.clone(), warn_if_unused); } diff --git a/compiler/noirc_frontend/src/elaborator/patterns.rs b/compiler/noirc_frontend/src/elaborator/patterns.rs index 3928362db11..3fbdadbbee8 100644 --- a/compiler/noirc_frontend/src/elaborator/patterns.rs +++ b/compiler/noirc_frontend/src/elaborator/patterns.rs @@ -331,16 +331,18 @@ impl<'context> Elaborator<'context> { let resolver_meta = ResolverMeta { num_times_used: 0, ident: ident.clone(), warn_if_unused }; - let scope = self.scopes.get_mut_scope(); - let old_value = scope.add_key_value(name.clone(), resolver_meta); - - if !allow_shadowing { - if let Some(old_value) = old_value { - self.push_err(ResolverError::DuplicateDefinition { - name, - first_span: old_value.ident.location.span, - second_span: location.span, - }); + if name != "_" { + let scope = self.scopes.get_mut_scope(); + let old_value = scope.add_key_value(name.clone(), resolver_meta); + + if !allow_shadowing { + if let Some(old_value) = old_value { + self.push_err(ResolverError::DuplicateDefinition { + name, + first_span: old_value.ident.location.span, + second_span: location.span, + }); + } } } diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index 4e021921937..ffa24790791 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -1347,11 +1347,23 @@ impl<'context> Elaborator<'context> { { Some(method_id) => Some(HirMethodReference::FuncId(method_id)), None => { - self.push_err(TypeCheckError::UnresolvedMethodCall { - method_name: method_name.to_string(), - object_type: object_type.clone(), - span, - }); + let has_field_with_function_type = + typ.borrow().get_fields_as_written().into_iter().any(|field| { + field.name.0.contents == method_name && field.typ.is_function() + }); + if has_field_with_function_type { + self.push_err(TypeCheckError::CannotInvokeStructFieldFunctionType { + method_name: method_name.to_string(), + object_type: object_type.clone(), + span, + }); + } else { + self.push_err(TypeCheckError::UnresolvedMethodCall { + method_name: method_name.to_string(), + object_type: object_type.clone(), + span, + }); + } None } } diff --git a/compiler/noirc_frontend/src/hir/comptime/errors.rs b/compiler/noirc_frontend/src/hir/comptime/errors.rs index 198ba91156e..446c4dae2d3 100644 --- a/compiler/noirc_frontend/src/hir/comptime/errors.rs +++ b/compiler/noirc_frontend/src/hir/comptime/errors.rs @@ -200,6 +200,11 @@ pub enum InterpreterError { item: String, location: Location, }, + InvalidInComptimeContext { + item: String, + location: Location, + explanation: String, + }, TypeAnnotationsNeededForMethodCall { location: Location, }, @@ -291,6 +296,7 @@ impl InterpreterError { | InterpreterError::UnsupportedTopLevelItemUnquote { location, .. } | InterpreterError::ComptimeDependencyCycle { location, .. } | InterpreterError::Unimplemented { location, .. } + | InterpreterError::InvalidInComptimeContext { location, .. } | InterpreterError::NoImpl { location, .. } | InterpreterError::ImplMethodTypeMismatch { location, .. } | InterpreterError::DebugEvaluateComptime { location, .. 
} @@ -540,6 +546,10 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic { let msg = format!("{item} is currently unimplemented"); CustomDiagnostic::simple_error(msg, String::new(), location.span) } + InterpreterError::InvalidInComptimeContext { item, location, explanation } => { + let msg = format!("{item} is invalid in comptime context"); + CustomDiagnostic::simple_error(msg, explanation.clone(), location.span) + } InterpreterError::BreakNotInLoop { location } => { let msg = "There is no loop to break out of!".into(); CustomDiagnostic::simple_error(msg, String::new(), location.span) diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index 994318a371a..49fd86b73bb 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -1,6 +1,7 @@ use std::collections::VecDeque; use std::{collections::hash_map::Entry, rc::Rc}; +use acvm::blackbox_solver::BigIntSolverWithId; use acvm::{acir::AcirField, FieldElement}; use fm::FileId; use im::Vector; @@ -62,6 +63,9 @@ pub struct Interpreter<'local, 'interner> { /// multiple times. Without this map, when one of these inner functions exits we would /// unbind the generic completely instead of resetting it to its previous binding. bound_generics: Vec>, + + /// Stateful bigint calculator. + bigint_solver: BigIntSolverWithId, } #[allow(unused)] @@ -71,9 +75,14 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { crate_id: CrateId, current_function: Option, ) -> Self { - let bound_generics = Vec::new(); - let in_loop = false; - Self { elaborator, crate_id, current_function, bound_generics, in_loop } + Self { + elaborator, + crate_id, + current_function, + bound_generics: Vec::new(), + in_loop: false, + bigint_solver: BigIntSolverWithId::default(), + } } pub(crate) fn call_function( @@ -227,11 +236,9 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { .expect("all builtin functions must contain a function attribute which contains the opcode which it links to"); if let Some(builtin) = func_attrs.builtin() { - let builtin = builtin.clone(); - self.call_builtin(&builtin, arguments, return_type, location) + self.call_builtin(builtin.clone().as_str(), arguments, return_type, location) } else if let Some(foreign) = func_attrs.foreign() { - let foreign = foreign.clone(); - foreign::call_foreign(self.elaborator.interner, &foreign, arguments, location) + self.call_foreign(foreign.clone().as_str(), arguments, return_type, location) } else if let Some(oracle) = func_attrs.oracle() { if oracle == "print" { self.print_oracle(arguments) @@ -906,6 +913,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { } } + #[allow(clippy::bool_comparison)] fn evaluate_infix(&mut self, infix: HirInfixExpression, id: ExprId) -> IResult { let lhs_value = self.evaluate(infix.lhs)?; let rhs_value = self.evaluate(infix.rhs)?; @@ -924,310 +932,183 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { InterpreterError::InvalidValuesForBinary { lhs, rhs, location, operator } }; - use InterpreterError::InvalidValuesForBinary; - match infix.operator.kind { - BinaryOpKind::Add => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Field(lhs + rhs)), - (Value::I8(lhs), Value::I8(rhs)) => { - Ok(Value::I8(lhs.checked_add(rhs).ok_or(error("+"))?)) - } - (Value::I16(lhs), Value::I16(rhs)) => { - Ok(Value::I16(lhs.checked_add(rhs).ok_or(error("+"))?)) - } - 
(Value::I32(lhs), Value::I32(rhs)) => { - Ok(Value::I32(lhs.checked_add(rhs).ok_or(error("+"))?)) + /// Generate matches that can promote the type of one side to the other if they are compatible. + macro_rules! match_values { + (($lhs_value:ident as $lhs:ident $op:literal $rhs_value:ident as $rhs:ident) { + $( + ($lhs_var:ident, $rhs_var:ident) to $res_var:ident => $expr:expr + ),* + $(,)? + } + ) => { + match ($lhs_value, $rhs_value) { + $( + (Value::$lhs_var($lhs), Value::$rhs_var($rhs)) => { + Ok(Value::$res_var(($expr).ok_or(error($op))?)) + }, + )* + (lhs, rhs) => { + Err(error($op)) + }, + } + }; + } + + /// Generate matches for arithmetic operations on `Field` and integers. + macro_rules! match_arithmetic { + (($lhs_value:ident as $lhs:ident $op:literal $rhs_value:ident as $rhs:ident) { field: $field_expr:expr, int: $int_expr:expr, }) => { + match_values! { + ($lhs_value as $lhs $op $rhs_value as $rhs) { + (Field, Field) to Field => Some($field_expr), + (I8, I8) to I8 => $int_expr, + (I16, I16) to I16 => $int_expr, + (I32, I32) to I32 => $int_expr, + (I64, I64) to I64 => $int_expr, + (U8, U8) to U8 => $int_expr, + (U16, U16) to U16 => $int_expr, + (U32, U32) to U32 => $int_expr, + (U64, U64) to U64 => $int_expr, + } } - (Value::I64(lhs), Value::I64(rhs)) => { - Ok(Value::I64(lhs.checked_add(rhs).ok_or(error("+"))?)) + }; + } + + /// Generate matches for comparison operations on all types, returning `Bool`. + macro_rules! match_cmp { + (($lhs_value:ident as $lhs:ident $op:literal $rhs_value:ident as $rhs:ident) => $expr:expr) => { + match_values! { + ($lhs_value as $lhs $op $rhs_value as $rhs) { + (Field, Field) to Bool => Some($expr), + (Bool, Bool) to Bool => Some($expr), + (I8, I8) to Bool => Some($expr), + (I16, I16) to Bool => Some($expr), + (I32, I32) to Bool => Some($expr), + (I64, I64) to Bool => Some($expr), + (U8, U8) to Bool => Some($expr), + (U16, U16) to Bool => Some($expr), + (U32, U32) to Bool => Some($expr), + (U64, U64) to Bool => Some($expr), + } } - (Value::U8(lhs), Value::U8(rhs)) => { - Ok(Value::U8(lhs.checked_add(rhs).ok_or(error("+"))?)) + }; + } + + /// Generate matches for bitwise operations on `Bool` and integers. + macro_rules! match_bitwise { + (($lhs_value:ident as $lhs:ident $op:literal $rhs_value:ident as $rhs:ident) => $expr:expr) => { + match_values! { + ($lhs_value as $lhs $op $rhs_value as $rhs) { + (Bool, Bool) to Bool => Some($expr), + (I8, I8) to I8 => Some($expr), + (I16, I16) to I16 => Some($expr), + (I32, I32) to I32 => Some($expr), + (I64, I64) to I64 => Some($expr), + (U8, U8) to U8 => Some($expr), + (U16, U16) to U16 => Some($expr), + (U32, U32) to U32 => Some($expr), + (U64, U64) to U64 => Some($expr), + } } - (Value::U16(lhs), Value::U16(rhs)) => { - Ok(Value::U16(lhs.checked_add(rhs).ok_or(error("+"))?)) + }; + } + + /// Generate matches for operations on just integer values. + macro_rules! match_integer { + (($lhs_value:ident as $lhs:ident $op:literal $rhs_value:ident as $rhs:ident) => $expr:expr) => { + match_values! { + ($lhs_value as $lhs $op $rhs_value as $rhs) { + (I8, I8) to I8 => $expr, + (I16, I16) to I16 => $expr, + (I32, I32) to I32 => $expr, + (I64, I64) to I64 => $expr, + (U8, U8) to U8 => $expr, + (U16, U16) to U16 => $expr, + (U32, U32) to U32 => $expr, + (U64, U64) to U64 => $expr, + } } - (Value::U32(lhs), Value::U32(rhs)) => { - Ok(Value::U32(lhs.checked_add(rhs).ok_or(error("+"))?)) + }; + } + + /// Generate matches for bit shifting, which in Noir only accepts `u8` for RHS. + macro_rules! 
match_bitshift { + (($lhs_value:ident as $lhs:ident $op:literal $rhs_value:ident as $rhs:ident) => $expr:expr) => { + match_values! { + ($lhs_value as $lhs $op $rhs_value as $rhs) { + (I8, U8) to I8 => $expr, + (I16, U8) to I16 => $expr, + (I32, U8) to I32 => $expr, + (I64, U8) to I64 => $expr, + (U8, U8) to U8 => $expr, + (U16, U8) to U16 => $expr, + (U32, U8) to U32 => $expr, + (U64, U8) to U64 => $expr, + } } - (Value::U64(lhs), Value::U64(rhs)) => { - Ok(Value::U64(lhs.checked_add(rhs).ok_or(error("+"))?)) + }; + } + + use InterpreterError::InvalidValuesForBinary; + match infix.operator.kind { + BinaryOpKind::Add => match_arithmetic! { + (lhs_value as lhs "+" rhs_value as rhs) { + field: lhs + rhs, + int: lhs.checked_add(rhs), } - (lhs, rhs) => Err(error("+")), }, - BinaryOpKind::Subtract => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Field(lhs - rhs)), - (Value::I8(lhs), Value::I8(rhs)) => { - Ok(Value::I8(lhs.checked_sub(rhs).ok_or(error("-"))?)) - } - (Value::I16(lhs), Value::I16(rhs)) => { - Ok(Value::I16(lhs.checked_sub(rhs).ok_or(error("-"))?)) - } - (Value::I32(lhs), Value::I32(rhs)) => { - Ok(Value::I32(lhs.checked_sub(rhs).ok_or(error("-"))?)) - } - (Value::I64(lhs), Value::I64(rhs)) => { - Ok(Value::I64(lhs.checked_sub(rhs).ok_or(error("-"))?)) - } - (Value::U8(lhs), Value::U8(rhs)) => { - Ok(Value::U8(lhs.checked_sub(rhs).ok_or(error("-"))?)) - } - (Value::U16(lhs), Value::U16(rhs)) => { - Ok(Value::U16(lhs.checked_sub(rhs).ok_or(error("-"))?)) + BinaryOpKind::Subtract => match_arithmetic! { + (lhs_value as lhs "-" rhs_value as rhs) { + field: lhs - rhs, + int: lhs.checked_sub(rhs), } - (Value::U32(lhs), Value::U32(rhs)) => { - Ok(Value::U32(lhs.checked_sub(rhs).ok_or(error("-"))?)) - } - (Value::U64(lhs), Value::U64(rhs)) => { - Ok(Value::U64(lhs.checked_sub(rhs).ok_or(error("-"))?)) - } - (lhs, rhs) => Err(error("-")), }, - BinaryOpKind::Multiply => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Field(lhs * rhs)), - (Value::I8(lhs), Value::I8(rhs)) => { - Ok(Value::I8(lhs.checked_mul(rhs).ok_or(error("*"))?)) - } - (Value::I16(lhs), Value::I16(rhs)) => { - Ok(Value::I16(lhs.checked_mul(rhs).ok_or(error("*"))?)) - } - (Value::I32(lhs), Value::I32(rhs)) => { - Ok(Value::I32(lhs.checked_mul(rhs).ok_or(error("*"))?)) - } - (Value::I64(lhs), Value::I64(rhs)) => { - Ok(Value::I64(lhs.checked_mul(rhs).ok_or(error("*"))?)) - } - (Value::U8(lhs), Value::U8(rhs)) => { - Ok(Value::U8(lhs.checked_mul(rhs).ok_or(error("*"))?)) - } - (Value::U16(lhs), Value::U16(rhs)) => { - Ok(Value::U16(lhs.checked_mul(rhs).ok_or(error("*"))?)) - } - (Value::U32(lhs), Value::U32(rhs)) => { - Ok(Value::U32(lhs.checked_mul(rhs).ok_or(error("*"))?)) + BinaryOpKind::Multiply => match_arithmetic! 
{ + (lhs_value as lhs "*" rhs_value as rhs) { + field: lhs * rhs, + int: lhs.checked_mul(rhs), } - (Value::U64(lhs), Value::U64(rhs)) => { - Ok(Value::U64(lhs.checked_mul(rhs).ok_or(error("*"))?)) - } - (lhs, rhs) => Err(error("*")), }, - BinaryOpKind::Divide => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Field(lhs / rhs)), - (Value::I8(lhs), Value::I8(rhs)) => { - Ok(Value::I8(lhs.checked_div(rhs).ok_or(error("/"))?)) - } - (Value::I16(lhs), Value::I16(rhs)) => { - Ok(Value::I16(lhs.checked_div(rhs).ok_or(error("/"))?)) - } - (Value::I32(lhs), Value::I32(rhs)) => { - Ok(Value::I32(lhs.checked_div(rhs).ok_or(error("/"))?)) - } - (Value::I64(lhs), Value::I64(rhs)) => { - Ok(Value::I64(lhs.checked_div(rhs).ok_or(error("/"))?)) - } - (Value::U8(lhs), Value::U8(rhs)) => { - Ok(Value::U8(lhs.checked_div(rhs).ok_or(error("/"))?)) - } - (Value::U16(lhs), Value::U16(rhs)) => { - Ok(Value::U16(lhs.checked_div(rhs).ok_or(error("/"))?)) - } - (Value::U32(lhs), Value::U32(rhs)) => { - Ok(Value::U32(lhs.checked_div(rhs).ok_or(error("/"))?)) + BinaryOpKind::Divide => match_arithmetic! { + (lhs_value as lhs "/" rhs_value as rhs) { + field: lhs / rhs, + int: lhs.checked_div(rhs), } - (Value::U64(lhs), Value::U64(rhs)) => { - Ok(Value::U64(lhs.checked_div(rhs).ok_or(error("/"))?)) - } - (lhs, rhs) => Err(error("/")), }, - BinaryOpKind::Equal => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs == rhs)), - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs == rhs)), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::Bool(lhs == rhs)), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs == rhs)), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs == rhs)), - (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs == rhs)), - (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs == rhs)), - (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs == rhs)), - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs == rhs)), - (Value::Bool(lhs), Value::Bool(rhs)) => Ok(Value::Bool(lhs == rhs)), - (lhs, rhs) => Err(error("==")), + BinaryOpKind::Equal => match_cmp! { + (lhs_value as lhs "==" rhs_value as rhs) => lhs == rhs }, - BinaryOpKind::NotEqual => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs != rhs)), - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs != rhs)), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::Bool(lhs != rhs)), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs != rhs)), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs != rhs)), - (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs != rhs)), - (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs != rhs)), - (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs != rhs)), - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs != rhs)), - (Value::Bool(lhs), Value::Bool(rhs)) => Ok(Value::Bool(lhs != rhs)), - (lhs, rhs) => Err(error("!=")), + BinaryOpKind::NotEqual => match_cmp! 
{ + (lhs_value as lhs "!=" rhs_value as rhs) => lhs != rhs }, - BinaryOpKind::Less => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs < rhs)), - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs < rhs)), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::Bool(lhs < rhs)), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs < rhs)), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs < rhs)), - (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs < rhs)), - (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs < rhs)), - (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs < rhs)), - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs < rhs)), - (lhs, rhs) => Err(error("<")), + BinaryOpKind::Less => match_cmp! { + (lhs_value as lhs "<" rhs_value as rhs) => lhs < rhs }, - BinaryOpKind::LessEqual => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs <= rhs)), - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs <= rhs)), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::Bool(lhs <= rhs)), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs <= rhs)), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs <= rhs)), - (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs <= rhs)), - (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs <= rhs)), - (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs <= rhs)), - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs <= rhs)), - (lhs, rhs) => Err(error("<=")), + BinaryOpKind::LessEqual => match_cmp! { + (lhs_value as lhs "<=" rhs_value as rhs) => lhs <= rhs }, - BinaryOpKind::Greater => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs > rhs)), - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs > rhs)), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::Bool(lhs > rhs)), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs > rhs)), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs > rhs)), - (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs > rhs)), - (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs > rhs)), - (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs > rhs)), - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs > rhs)), - (lhs, rhs) => Err(error(">")), + BinaryOpKind::Greater => match_cmp! { + (lhs_value as lhs ">" rhs_value as rhs) => lhs > rhs }, - BinaryOpKind::GreaterEqual => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs >= rhs)), - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs >= rhs)), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::Bool(lhs >= rhs)), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs >= rhs)), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs >= rhs)), - (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs >= rhs)), - (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs >= rhs)), - (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs >= rhs)), - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs >= rhs)), - (lhs, rhs) => Err(error(">=")), + BinaryOpKind::GreaterEqual => match_cmp! 
{ + (lhs_value as lhs ">=" rhs_value as rhs) => lhs >= rhs }, - BinaryOpKind::And => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Bool(lhs), Value::Bool(rhs)) => Ok(Value::Bool(lhs & rhs)), - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs & rhs)), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::I16(lhs & rhs)), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs & rhs)), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs & rhs)), - (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs & rhs)), - (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs & rhs)), - (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs & rhs)), - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs & rhs)), - (lhs, rhs) => Err(error("&")), + BinaryOpKind::And => match_bitwise! { + (lhs_value as lhs "&" rhs_value as rhs) => lhs & rhs }, - BinaryOpKind::Or => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Bool(lhs), Value::Bool(rhs)) => Ok(Value::Bool(lhs | rhs)), - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs | rhs)), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::I16(lhs | rhs)), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs | rhs)), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs | rhs)), - (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs | rhs)), - (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs | rhs)), - (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs | rhs)), - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs | rhs)), - (lhs, rhs) => Err(error("|")), + BinaryOpKind::Or => match_bitwise! { + (lhs_value as lhs "|" rhs_value as rhs) => lhs | rhs }, - BinaryOpKind::Xor => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Bool(lhs), Value::Bool(rhs)) => Ok(Value::Bool(lhs ^ rhs)), - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs ^ rhs)), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::I16(lhs ^ rhs)), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs ^ rhs)), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs ^ rhs)), - (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs ^ rhs)), - (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs ^ rhs)), - (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs ^ rhs)), - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs ^ rhs)), - (lhs, rhs) => Err(error("^")), + BinaryOpKind::Xor => match_bitwise! 
{ + (lhs_value as lhs "^" rhs_value as rhs) => lhs ^ rhs }, - BinaryOpKind::ShiftRight => match (lhs_value.clone(), rhs_value.clone()) { - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8( - lhs.checked_shr(rhs.try_into().map_err(|_| error(">>"))?).ok_or(error(">>"))?, - )), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::I16( - lhs.checked_shr(rhs.try_into().map_err(|_| error(">>"))?).ok_or(error(">>"))?, - )), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32( - lhs.checked_shr(rhs.try_into().map_err(|_| error(">>"))?).ok_or(error(">>"))?, - )), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64( - lhs.checked_shr(rhs.try_into().map_err(|_| error(">>"))?).ok_or(error(">>"))?, - )), - (Value::U8(lhs), Value::U8(rhs)) => { - Ok(Value::U8(lhs.checked_shr(rhs.into()).ok_or(error(">>"))?)) - } - (Value::U16(lhs), Value::U16(rhs)) => { - Ok(Value::U16(lhs.checked_shr(rhs.into()).ok_or(error(">>"))?)) - } - (Value::U32(lhs), Value::U32(rhs)) => { - Ok(Value::U32(lhs.checked_shr(rhs).ok_or(error(">>"))?)) - } - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64( - lhs.checked_shr(rhs.try_into().map_err(|_| error(">>"))?).ok_or(error(">>"))?, - )), - (lhs, rhs) => Err(error(">>")), + BinaryOpKind::ShiftRight => match_bitshift! { + (lhs_value as lhs ">>" rhs_value as rhs) => lhs.checked_shr(rhs.into()) }, - BinaryOpKind::ShiftLeft => match (lhs_value.clone(), rhs_value.clone()) { - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8( - lhs.checked_shl(rhs.try_into().map_err(|_| error("<<"))?).ok_or(error("<<"))?, - )), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::I16( - lhs.checked_shl(rhs.try_into().map_err(|_| error("<<"))?).ok_or(error("<<"))?, - )), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32( - lhs.checked_shl(rhs.try_into().map_err(|_| error("<<"))?).ok_or(error("<<"))?, - )), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64( - lhs.checked_shl(rhs.try_into().map_err(|_| error("<<"))?).ok_or(error("<<"))?, - )), - (Value::U8(lhs), Value::U8(rhs)) => { - Ok(Value::U8(lhs.checked_shl(rhs.into()).ok_or(error("<<"))?)) - } - (Value::U16(lhs), Value::U16(rhs)) => { - Ok(Value::U16(lhs.checked_shl(rhs.into()).ok_or(error("<<"))?)) - } - (Value::U32(lhs), Value::U32(rhs)) => { - Ok(Value::U32(lhs.checked_shl(rhs).ok_or(error("<<"))?)) - } - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64( - lhs.checked_shl(rhs.try_into().map_err(|_| error("<<"))?).ok_or(error("<<"))?, - )), - (lhs, rhs) => Err(error("<<")), + BinaryOpKind::ShiftLeft => match_bitshift! 
{ + (lhs_value as lhs "<<" rhs_value as rhs) => lhs.checked_shl(rhs.into()) }, - BinaryOpKind::Modulo => match (lhs_value.clone(), rhs_value.clone()) { - (Value::I8(lhs), Value::I8(rhs)) => { - Ok(Value::I8(lhs.checked_rem(rhs).ok_or(error("%"))?)) - } - (Value::I16(lhs), Value::I16(rhs)) => { - Ok(Value::I16(lhs.checked_rem(rhs).ok_or(error("%"))?)) - } - (Value::I32(lhs), Value::I32(rhs)) => { - Ok(Value::I32(lhs.checked_rem(rhs).ok_or(error("%"))?)) - } - (Value::I64(lhs), Value::I64(rhs)) => { - Ok(Value::I64(lhs.checked_rem(rhs).ok_or(error("%"))?)) - } - (Value::U8(lhs), Value::U8(rhs)) => { - Ok(Value::U8(lhs.checked_rem(rhs).ok_or(error("%"))?)) - } - (Value::U16(lhs), Value::U16(rhs)) => { - Ok(Value::U16(lhs.checked_rem(rhs).ok_or(error("%"))?)) - } - (Value::U32(lhs), Value::U32(rhs)) => { - Ok(Value::U32(lhs.checked_rem(rhs).ok_or(error("%"))?)) - } - (Value::U64(lhs), Value::U64(rhs)) => { - Ok(Value::U64(lhs.checked_rem(rhs).ok_or(error("%"))?)) - } - (lhs, rhs) => Err(error("%")), + BinaryOpKind::Modulo => match_integer! { + (lhs_value as lhs "%" rhs_value as rhs) => lhs.checked_rem(rhs) }, } }
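The hunk above replaces roughly two hundred hand-written match arms with one table-driven `match_values!` macro plus a thin wrapper per operator family. A minimal standalone sketch of the same technique, using a simplified `Value` enum and hypothetical names rather than the compiler's actual types:

```rust
// One declarative macro expands a (LhsVariant, RhsVariant) -> ResultVariant
// table into match arms, so each operator needs one table, not ten arms.
#[derive(Debug, Clone, Copy)]
enum Value {
    U8(u8),
    U32(u32),
}

fn error(op: &str) -> String {
    format!("invalid operands for `{op}`")
}

macro_rules! match_values {
    (($lhs_value:ident $op:literal $rhs_value:ident) {
        $(($lhs_var:ident, $rhs_var:ident) to $res_var:ident => $f:expr),* $(,)?
    }) => {
        match ($lhs_value, $rhs_value) {
            $(
                (Value::$lhs_var(lhs), Value::$rhs_var(rhs)) => {
                    // Apply the checked operation, wrap the result in the
                    // declared result variant, or report the failed operator.
                    ($f)(lhs, rhs).map(Value::$res_var).ok_or_else(|| error($op))
                }
            )*
            _ => Err(error($op)),
        }
    };
}

fn checked_add(lhs_value: Value, rhs_value: Value) -> Result<Value, String> {
    match_values! {
        (lhs_value "+" rhs_value) {
            (U8, U8) to U8 => u8::checked_add,
            (U32, U32) to U32 => u32::checked_add,
        }
    }
}

fn main() {
    assert!(matches!(checked_add(Value::U8(250), Value::U8(5)), Ok(Value::U8(255))));
    assert!(checked_add(Value::U8(250), Value::U8(10)).is_err()); // overflow
    assert!(checked_add(Value::U8(1), Value::U32(2)).is_err()); // type mismatch
}
```

The real macros keep the same shape but thread the operator string into a located interpreter error and add one wrapper per family (arithmetic, comparison, bitwise, shift, integer-only) so each type table is written exactly once.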
diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index 160b7702488..3b9921d9244 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -1,14 +1,14 @@ use std::rc::Rc; -use acvm::{AcirField, FieldElement}; +use acvm::{acir::BlackBoxFunc, AcirField, FieldElement}; use builtin_helpers::{ - block_expression_to_value, check_argument_count, check_function_not_yet_resolved, - check_one_argument, check_three_arguments, check_two_arguments, get_bool, get_expr, get_field, - get_format_string, get_function_def, get_module, get_quoted, get_slice, get_struct, - get_trait_constraint, get_trait_def, get_trait_impl, get_tuple, get_type, get_typed_expr, - get_u32, get_unresolved_type, has_named_attribute, hir_pattern_to_tokens, - mutate_func_meta_type, parse, quote_ident, replace_func_meta_parameters, - replace_func_meta_return_type, + block_expression_to_value, byte_array_type, check_argument_count, + check_function_not_yet_resolved, check_one_argument, check_three_arguments, + check_two_arguments, get_bool, get_expr, get_field, get_format_string, get_function_def, + get_module, get_quoted, get_slice, get_struct, get_trait_constraint, get_trait_def, + get_trait_impl, get_tuple, get_type, get_typed_expr, get_u32, get_unresolved_type, + has_named_attribute, hir_pattern_to_tokens, mutate_func_meta_type, parse, quote_ident, + replace_func_meta_parameters, replace_func_meta_return_type, }; use im::Vector; use iter_extended::{try_vecmap, vecmap}; @@ -42,7 +42,7 @@ use crate::{ }; use self::builtin_helpers::{eq_item, get_array, get_ctstring, get_str, get_u8, hash_item, lex}; -use super::{foreign, Interpreter}; +use super::Interpreter; pub(crate) mod builtin_helpers; @@ -57,7 +59,9 @@ impl<'local, 'context> Interpreter<'local, 'context> { let interner = &mut self.elaborator.interner; let call_stack = &self.elaborator.interpreter_call_stack; match name { - "apply_range_constraint" => foreign::apply_range_constraint(arguments, location), + "apply_range_constraint" => { + self.call_foreign("range", arguments, return_type, location) + } "array_as_str_unchecked" => array_as_str_unchecked(interner, arguments, location), "array_len" => array_len(interner, arguments, location), "array_refcount" => Ok(Value::U32(0)), @@ -234,8 +236,11 @@ impl<'local, 'context> Interpreter<'local, 'context> { "unresolved_type_is_field" => unresolved_type_is_field(interner, arguments, location), "unresolved_type_is_unit" => unresolved_type_is_unit(interner, arguments, location), "zeroed" => zeroed(return_type, location.span), + blackbox if BlackBoxFunc::is_valid_black_box_func_name(blackbox) => { + self.call_foreign(blackbox, arguments, return_type, location) + } _ => { - let item = format!("Comptime evaluation for builtin function {name}"); + let item = format!("Comptime evaluation for builtin function '{name}'"); Err(InterpreterError::Unimplemented { item, location }) } } @@ -324,10 +329,7 @@ fn str_as_bytes( let string = get_str(interner, string)?; let bytes: im::Vector<Value> = string.bytes().map(Value::U8).collect(); - let byte_array_type = Type::Array( - Box::new(Type::Constant(bytes.len().into(), Kind::u32())), - Box::new(Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight)), - ); + let byte_array_type = byte_array_type(bytes.len()); Ok(Value::Array(bytes, byte_array_type)) } @@ -820,10 +822,8 @@ fn to_le_radix( Some(digit) => Value::U8(*digit), None => Value::U8(0), }); - Ok(Value::Array( - decomposed_integer.into(), - Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight), - )) + let result_type = byte_array_type(decomposed_integer.len()); + Ok(Value::Array(decomposed_integer.into(), result_type)) } fn compute_to_radix_le(field: FieldElement, radix: u32) -> Vec<u8> { diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs index 3f9d92cfe88..cf90aab32e0 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs @@ -2,6 +2,7 @@ use std::hash::Hash; use std::{hash::Hasher, rc::Rc}; use acvm::FieldElement; +use iter_extended::try_vecmap; use noirc_errors::Location; use crate::hir::comptime::display::tokens_to_string; @@ -30,6 +31,8 @@ use crate::{ token::{SecondaryAttribute, Token, Tokens}, QuotedType, Type, }; +use crate::{Kind, Shared, StructType}; +use rustc_hash::FxHashMap as HashMap; pub(crate) fn check_argument_count( expected: usize, @@ -45,38 +48,40 @@ pub(crate) fn check_argument_count( } pub(crate) fn check_one_argument( - mut arguments: Vec<(Value, Location)>, + arguments: Vec<(Value, Location)>, location: Location, ) -> IResult<(Value, Location)> { - check_argument_count(1, &arguments, location)?; + let [arg1] = check_arguments(arguments, location)?; - Ok(arguments.pop().unwrap()) + Ok(arg1) } pub(crate) fn check_two_arguments( - mut arguments: Vec<(Value, Location)>, + arguments: Vec<(Value, Location)>, location: Location, ) -> IResult<((Value, Location), (Value, Location))> { - check_argument_count(2, &arguments, location)?; - - let argument2 = arguments.pop().unwrap(); - let argument1 = arguments.pop().unwrap(); + let [arg1, arg2] = check_arguments(arguments, location)?; - Ok((argument1, argument2)) + Ok((arg1, arg2)) } #[allow(clippy::type_complexity)] pub(crate) fn check_three_arguments( - mut arguments: Vec<(Value, Location)>, + arguments: Vec<(Value, Location)>, location: Location, ) -> IResult<((Value, Location), (Value, Location), (Value, Location))> { - check_argument_count(3, &arguments, location)?; + let [arg1, arg2, arg3] = check_arguments(arguments, location)?; - let argument3 = arguments.pop().unwrap(); - let argument2 = arguments.pop().unwrap(); - let argument1 =
arguments.pop().unwrap(); + Ok((arg1, arg2, arg3)) } - Ok((argument1, argument2, argument3)) +#[allow(clippy::type_complexity)] +pub(crate) fn check_arguments<const N: usize>( + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult<[(Value, Location); N]> { + check_argument_count(N, &arguments, location)?; + Ok(arguments.try_into().expect("checked arg count")) } pub(crate) fn get_array( @@ -93,6 +98,47 @@ pub(crate) fn get_array( } } +/// Get the fields if the value is a `Value::Struct`, otherwise report that a struct type +/// with `name` is expected. Returns the `Type` but doesn't verify that it's called `name`. +pub(crate) fn get_struct_fields( + name: &str, + (value, location): (Value, Location), +) -> IResult<(HashMap<Rc<String>, Value>, Type)> { + match value { + Value::Struct(fields, typ) => Ok((fields, typ)), + _ => { + let expected = StructType::new( + StructId::dummy_id(), + Ident::new(name.to_string(), location.span), + location, + Vec::new(), + Vec::new(), + ); + let expected = Type::Struct(Shared::new(expected), Vec::new()); + type_mismatch(value, expected, location) + } + } +} + +/// Get a specific field of a struct and apply a decoder function on it. +pub(crate) fn get_struct_field<T>( + field_name: &str, + struct_fields: &HashMap<Rc<String>, Value>, + struct_type: &Type, + location: Location, + f: impl Fn((Value, Location)) -> IResult<T>, +) -> IResult<T> { + let key = Rc::new(field_name.to_string()); + let Some(value) = struct_fields.get(&key) else { + return Err(InterpreterError::ExpectedStructToHaveField { + typ: struct_type.clone(), + field_name: Rc::into_inner(key).unwrap(), + location, + }); + }; + f((value.clone(), location)) +} + pub(crate) fn get_bool((value, location): (Value, Location)) -> IResult<bool> { match value { Value::Bool(value) => Ok(value), @@ -114,6 +160,49 @@ pub(crate) fn get_slice( } } +/// Interpret the input as a slice, then map each element. +/// Returns the values in the slice and the original type. +pub(crate) fn get_slice_map<T>( + interner: &NodeInterner, + (value, location): (Value, Location), + f: impl Fn((Value, Location)) -> IResult<T>, +) -> IResult<(Vec<T>, Type)> { + let (values, typ) = get_slice(interner, (value, location))?; + let values = try_vecmap(values, |value| f((value, location)))?; + Ok((values, typ)) +} + +/// Interpret the input as an array, then map each element. +/// Returns the values in the array and the original array type. +pub(crate) fn get_array_map<T>( + interner: &NodeInterner, + (value, location): (Value, Location), + f: impl Fn((Value, Location)) -> IResult<T>, +) -> IResult<(Vec<T>, Type)> { + let (values, typ) = get_array(interner, (value, location))?; + let values = try_vecmap(values, |value| f((value, location)))?; + Ok((values, typ)) +} + +/// Get an array and convert it to a fixed size. +/// Returns the values in the array and the original array type. +pub(crate) fn get_fixed_array_map<const N: usize, T>( + interner: &NodeInterner, + (value, location): (Value, Location), + f: impl Fn((Value, Location)) -> IResult<T>, +) -> IResult<([T; N], Type)> { + let (values, typ) = get_array_map(interner, (value, location), f)?; + + values.try_into().map(|v| (v, typ.clone())).map_err(|_| { + // Assuming that `values.len()` corresponds to `typ`.
+ let Type::Array(_, ref elem) = typ else { + unreachable!("get_array_map checked it was an array") + }; + let expected = Type::Array(Box::new(Type::Constant(N.into(), Kind::u32())), elem.clone()); + InterpreterError::TypeMismatch { expected, actual: typ, location } + }) +} + pub(crate) fn get_str( interner: &NodeInterner, (value, location): (Value, Location), @@ -520,3 +609,44 @@ pub(super) fn eq_item( let other_arg = get_item(other_arg)?; Ok(Value::Bool(self_arg == other_arg)) } + +/// Type to be used in `Value::Array(<values>, <type>)`. +pub(crate) fn byte_array_type(len: usize) -> Type { + Type::Array( + Box::new(Type::Constant(len.into(), Kind::u32())), + Box::new(Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight)), + ) +} + +/// Type to be used in `Value::Slice(<values>, <type>)`. +pub(crate) fn byte_slice_type() -> Type { + Type::Slice(Box::new(Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight))) +} + +/// Create a `Value::Array` from bytes. +pub(crate) fn to_byte_array(values: &[u8]) -> Value { + Value::Array(values.iter().copied().map(Value::U8).collect(), byte_array_type(values.len())) +} + +/// Create a `Value::Slice` from bytes. +pub(crate) fn to_byte_slice(values: &[u8]) -> Value { + Value::Slice(values.iter().copied().map(Value::U8).collect(), byte_slice_type()) +} + +/// Create a `Value::Array` from fields. +pub(crate) fn to_field_array(values: &[FieldElement]) -> Value { + let typ = Type::Array( + Box::new(Type::Constant(values.len().into(), Kind::u32())), + Box::new(Type::FieldElement), + ); + Value::Array(values.iter().copied().map(Value::Field).collect(), typ) +} + +/// Create a `Value::Struct` from fields and the expected return type. +pub(crate) fn to_struct( + fields: impl IntoIterator<Item = (&'static str, Value)>, + typ: Type, +) -> Value { + let fields = fields.into_iter().map(|(k, v)| (Rc::new(k.to_string()), v)).collect(); + Value::Struct(fields, typ) +}
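The new `check_arguments` and `get_fixed_array_map` helpers lean on const generics plus `TryInto` to turn a runtime-length `Vec` into a compile-time-sized array, with the length coming from the destructuring pattern at the call site. A minimal sketch of that pattern under a simplified error type (names here are illustrative, not the compiler's):

```rust
// `Vec<T>` implements `TryInto<[T; N]>`, which succeeds exactly when the
// lengths match, so one generic function replaces the old
// check_one/check_two/check_three boilerplate.
#[derive(Debug)]
struct ArgumentCountMismatch {
    expected: usize,
    actual: usize,
}

fn check_arguments<T, const N: usize>(args: Vec<T>) -> Result<[T; N], ArgumentCountMismatch> {
    let actual = args.len();
    args.try_into().map_err(|_| ArgumentCountMismatch { expected: N, actual })
}

fn main() {
    // Destructuring the returned array both fixes N and names the arguments,
    // replacing repeated `pop().unwrap()` calls in reverse order.
    let [a, b] = check_arguments(vec![1, 2]).unwrap();
    assert_eq!((a, b), (1, 2));

    // A wrong arity is reported as a structured error instead of a panic.
    let err = check_arguments::<i32, 3>(vec![1, 2]).unwrap_err();
    assert_eq!((err.expected, err.actual), (3, 2));
}
```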
diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/foreign.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/foreign.rs index 3de72969cab..d2611f72535 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/foreign.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/foreign.rs @@ -1,40 +1,126 @@ use acvm::{ - acir::BlackBoxFunc, blackbox_solver::BlackBoxFunctionSolver, AcirField, BlackBoxResolutionError, + acir::BlackBoxFunc, + blackbox_solver::{BigIntSolverWithId, BlackBoxFunctionSolver}, + AcirField, BlackBoxResolutionError, FieldElement, }; -use bn254_blackbox_solver::Bn254BlackBoxSolver; +use bn254_blackbox_solver::Bn254BlackBoxSolver; // Currently locked to only bn254! use im::Vector; -use iter_extended::try_vecmap; use noirc_errors::Location; use crate::{ - hir::comptime::{errors::IResult, InterpreterError, Value}, + hir::comptime::{ + errors::IResult, interpreter::builtin::builtin_helpers::to_byte_array, InterpreterError, + Value, + }, node_interner::NodeInterner, + Type, }; -use super::builtin::builtin_helpers::{ - check_one_argument, check_two_arguments, get_array, get_field, get_u32, get_u64, +use super::{ + builtin::builtin_helpers::{ + check_arguments, check_one_argument, check_three_arguments, check_two_arguments, + get_array_map, get_bool, get_field, get_fixed_array_map, get_slice_map, get_struct_field, + get_struct_fields, get_u32, get_u64, get_u8, to_byte_slice, to_field_array, to_struct, + }, + Interpreter, }; -pub(super) fn call_foreign( +impl<'local, 'context> Interpreter<'local, 'context> { + pub(super) fn call_foreign( + &mut self, + name: &str, + arguments: Vec<(Value, Location)>, + return_type: Type, + location: Location, + ) -> IResult<Value> { + call_foreign( + self.elaborator.interner, + &mut self.bigint_solver, + name, + arguments, + return_type, + location, + ) + } +} + +// Similar to `evaluate_black_box` in `brillig_vm`. +fn call_foreign( interner: &mut NodeInterner, + bigint_solver: &mut BigIntSolverWithId, name: &str, - arguments: Vec<(Value, Location)>, + args: Vec<(Value, Location)>, + return_type: Type, location: Location, ) -> IResult<Value> { + use BlackBoxFunc::*; + match name { - "poseidon2_permutation" => poseidon2_permutation(interner, arguments, location), - "keccakf1600" => keccakf1600(interner, arguments, location), + "aes128_encrypt" => aes128_encrypt(interner, args, location), + "bigint_from_le_bytes" => { + bigint_from_le_bytes(interner, bigint_solver, args, return_type, location) + } + "bigint_to_le_bytes" => bigint_to_le_bytes(bigint_solver, args, location), + "bigint_add" => bigint_op(bigint_solver, BigIntAdd, args, return_type, location), + "bigint_sub" => bigint_op(bigint_solver, BigIntSub, args, return_type, location), + "bigint_mul" => bigint_op(bigint_solver, BigIntMul, args, return_type, location), + "bigint_div" => bigint_op(bigint_solver, BigIntDiv, args, return_type, location), + "blake2s" => blake_hash(interner, args, location, acvm::blackbox_solver::blake2s), + "blake3" => blake_hash(interner, args, location, acvm::blackbox_solver::blake3), + "ecdsa_secp256k1" => ecdsa_secp256_verify( + interner, + args, + location, + acvm::blackbox_solver::ecdsa_secp256k1_verify, + ), + "ecdsa_secp256r1" => ecdsa_secp256_verify( + interner, + args, + location, + acvm::blackbox_solver::ecdsa_secp256r1_verify, + ), + "embedded_curve_add" => embedded_curve_add(args, location), + "multi_scalar_mul" => multi_scalar_mul(interner, args, location), + "poseidon2_permutation" => poseidon2_permutation(interner, args, location), + "keccakf1600" => keccakf1600(interner, args, location), + "range" => apply_range_constraint(args, location), + "sha256_compression" => sha256_compression(interner, args, location), _ => { - let item = format!("Comptime evaluation for builtin function {name}"); - Err(InterpreterError::Unimplemented { item, location }) + let explanation = match name { + "schnorr_verify" => "Schnorr verification will be removed.".into(), + "and" | "xor" => "It should be turned into a binary operation.".into(), + "recursive_aggregation" => "A proof cannot be verified at comptime.".into(), + _ => { + let item = format!("Comptime evaluation for foreign function '{name}'"); + return Err(InterpreterError::Unimplemented { item, location });
+ } + }; + + let item = format!("Attempting to evaluate foreign function '{name}'"); + Err(InterpreterError::InvalidInComptimeContext { item, location, explanation }) } } } -pub(super) fn apply_range_constraint( +/// `pub fn aes128_encrypt(input: [u8; N], iv: [u8; 16], key: [u8; 16]) -> [u8]` +fn aes128_encrypt( + interner: &mut NodeInterner, arguments: Vec<(Value, Location)>, location: Location, ) -> IResult<Value> { + let (inputs, iv, key) = check_three_arguments(arguments, location)?; + + let (inputs, _) = get_array_map(interner, inputs, get_u8)?; + let (iv, _) = get_fixed_array_map(interner, iv, get_u8)?; + let (key, _) = get_fixed_array_map(interner, key, get_u8)?; + + let output = acvm::blackbox_solver::aes128_encrypt(&inputs, iv, key) + .map_err(|e| InterpreterError::BlackBoxError(e, location))?; + + Ok(to_byte_slice(&output)) +} + +fn apply_range_constraint(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> { let (value, num_bits) = check_two_arguments(arguments, location)?; let input = get_field(value)?; @@ -53,21 +139,192 @@ pub(super) fn apply_range_constraint( } } -// poseidon2_permutation(_input: [Field; N], _state_length: u32) -> [Field; N] +/// `fn from_le_bytes(bytes: [u8], modulus: [u8]) -> BigInt` +/// +/// Returns the ID of the new bigint allocated by the solver. +fn bigint_from_le_bytes( + interner: &mut NodeInterner, + solver: &mut BigIntSolverWithId, + arguments: Vec<(Value, Location)>, + return_type: Type, + location: Location, +) -> IResult<Value> { + let (bytes, modulus) = check_two_arguments(arguments, location)?; + + let (bytes, _) = get_slice_map(interner, bytes, get_u8)?; + let (modulus, _) = get_slice_map(interner, modulus, get_u8)?; + + let id = solver + .bigint_from_bytes(&bytes, &modulus) + .map_err(|e| InterpreterError::BlackBoxError(e, location))?; + + Ok(to_bigint(id, return_type)) +} + +/// `fn to_le_bytes(self) -> [u8; 32]` +/// +/// Takes the ID of a bigint and returns its content. +fn bigint_to_le_bytes( + solver: &mut BigIntSolverWithId, + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult<Value> { + let int = check_one_argument(arguments, location)?; + let id = get_bigint_id(int)?; + + let mut bytes = + solver.bigint_to_bytes(id).map_err(|e| InterpreterError::BlackBoxError(e, location))?; + + assert!(bytes.len() <= 32); + bytes.resize(32, 0); + + Ok(to_byte_array(&bytes)) +} + +/// `fn bigint_add(self, other: BigInt) -> BigInt` +/// +/// Takes two previously allocated IDs, gets the values from the solver, +/// stores the result of the operation, returns the new ID. +fn bigint_op( + solver: &mut BigIntSolverWithId, + func: BlackBoxFunc, + arguments: Vec<(Value, Location)>, + return_type: Type, + location: Location, +) -> IResult<Value> { + let (lhs, rhs) = check_two_arguments(arguments, location)?; + + let lhs = get_bigint_id(lhs)?; + let rhs = get_bigint_id(rhs)?; + + let id = solver + .bigint_op(lhs, rhs, func) + .map_err(|e| InterpreterError::BlackBoxError(e, location))?; + + Ok(to_bigint(id, return_type)) +}
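Note how the bigint built-ins never pass big integers around as `Value`s: a `BigInt` struct's `pointer` and `modulus` fields both carry a `u32` ID into solver-side storage, and every operation allocates a fresh ID. A toy sketch of this opaque-handle pattern (the `Store` type and its methods are hypothetical stand-ins, not the acvm `BigIntSolverWithId` API):

```rust
// Values live inside the solver; the interpreter only moves u32 handles.
use std::collections::HashMap;

#[derive(Default)]
struct Store {
    next_id: u32,
    values: HashMap<u32, u128>, // stand-in for a big integer plus its modulus
}

impl Store {
    fn alloc(&mut self, value: u128) -> u32 {
        let id = self.next_id;
        self.next_id += 1;
        self.values.insert(id, value);
        id
    }

    /// Reads both operands by ID and allocates a new handle for the result,
    /// mirroring how `bigint_add` returns a fresh ID instead of a value.
    fn add(&mut self, lhs: u32, rhs: u32) -> Option<u32> {
        let result = self.values.get(&lhs)?.checked_add(*self.values.get(&rhs)?)?;
        Some(self.alloc(result))
    }
}

fn main() {
    let mut store = Store::default();
    let a = store.alloc(10);
    let b = store.alloc(32);
    let c = store.add(a, b).expect("both handles are live");
    assert_eq!(store.values[&c], 42);
}
```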
+ +/// Run one of the Blake hash functions. +/// ```text +/// pub fn blake2s(input: [u8; N]) -> [u8; 32] +/// pub fn blake3(input: [u8; N]) -> [u8; 32] +/// ``` +fn blake_hash( + interner: &mut NodeInterner, + arguments: Vec<(Value, Location)>, + location: Location, + f: impl Fn(&[u8]) -> Result<[u8; 32], BlackBoxResolutionError>, +) -> IResult<Value> { + let inputs = check_one_argument(arguments, location)?; + + let (inputs, _) = get_array_map(interner, inputs, get_u8)?; + let output = f(&inputs).map_err(|e| InterpreterError::BlackBoxError(e, location))?; + + Ok(to_byte_array(&output)) +} + +/// Run one of the Secp256 signature verifications. +/// ```text +/// pub fn verify_signature( +/// public_key_x: [u8; 32], +/// public_key_y: [u8; 32], +/// signature: [u8; 64], +/// message_hash: [u8; N], +/// ) -> bool +/// +/// pub fn verify_signature_slice( +/// public_key_x: [u8; 32], +/// public_key_y: [u8; 32], +/// signature: [u8; 64], +/// message_hash: [u8], +/// ) -> bool +/// ``` +fn ecdsa_secp256_verify( + interner: &mut NodeInterner, + arguments: Vec<(Value, Location)>, + location: Location, + f: impl Fn(&[u8], &[u8; 32], &[u8; 32], &[u8; 64]) -> Result<bool, BlackBoxResolutionError>, +) -> IResult<Value> { + let [pub_key_x, pub_key_y, sig, msg_hash] = check_arguments(arguments, location)?; + + let (pub_key_x, _) = get_fixed_array_map(interner, pub_key_x, get_u8)?; + let (pub_key_y, _) = get_fixed_array_map(interner, pub_key_y, get_u8)?; + let (sig, _) = get_fixed_array_map(interner, sig, get_u8)?; + + // Hash can be an array or slice. + let (msg_hash, _) = if matches!(msg_hash.0.get_type().as_ref(), Type::Array(_, _)) { + get_array_map(interner, msg_hash.clone(), get_u8)? + } else { + get_slice_map(interner, msg_hash, get_u8)? + }; + + let is_valid = f(&msg_hash, &pub_key_x, &pub_key_y, &sig) + .map_err(|e| InterpreterError::BlackBoxError(e, location))?; + + Ok(Value::Bool(is_valid)) +} + +/// ```text +/// fn embedded_curve_add( +/// point1: EmbeddedCurvePoint, +/// point2: EmbeddedCurvePoint, +/// ) -> [Field; 3] +/// ``` +fn embedded_curve_add(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> { + let (point1, point2) = check_two_arguments(arguments, location)?; + + let (p1x, p1y, p1inf) = get_embedded_curve_point(point1)?; + let (p2x, p2y, p2inf) = get_embedded_curve_point(point2)?; + + let (x, y, inf) = Bn254BlackBoxSolver + .ec_add(&p1x, &p1y, &p1inf.into(), &p2x, &p2y, &p2inf.into()) + .map_err(|e| InterpreterError::BlackBoxError(e, location))?; + + Ok(to_field_array(&[x, y, inf])) +} + +/// ```text +/// pub fn multi_scalar_mul( +/// points: [EmbeddedCurvePoint; N], +/// scalars: [EmbeddedCurveScalar; N], +/// ) -> [Field; 3] +/// ``` +fn multi_scalar_mul( + interner: &mut NodeInterner, + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult<Value> { + let (points, scalars) = check_two_arguments(arguments, location)?; + + let (points, _) = get_array_map(interner, points, get_embedded_curve_point)?; + let (scalars, _) = get_array_map(interner, scalars, get_embedded_curve_scalar)?; + + let points: Vec<_> = points.into_iter().flat_map(|(x, y, inf)| [x, y, inf.into()]).collect(); + let mut scalars_lo = Vec::new(); + let mut scalars_hi = Vec::new(); + for (lo, hi) in scalars { + scalars_lo.push(lo); + scalars_hi.push(hi); + } + + let (x, y, inf) = Bn254BlackBoxSolver + .multi_scalar_mul(&points, &scalars_lo, &scalars_hi) + .map_err(|e| InterpreterError::BlackBoxError(e, location))?; + + Ok(to_field_array(&[x, y, inf])) +} + +/// `poseidon2_permutation(_input: [Field; N], _state_length: u32) -> [Field; N]` fn poseidon2_permutation( interner: &mut
NodeInterner, arguments: Vec<(Value, Location)>, location: Location, ) -> IResult<Value> { let (input, state_length) = check_two_arguments(arguments, location)?; - let input_location = input.1; - let (input, typ) = get_array(interner, input)?; + let (input, typ) = get_array_map(interner, input, get_field)?; let state_length = get_u32(state_length)?; - let input = try_vecmap(input, |integer| get_field((integer, input_location)))?; - - // Currently locked to only bn254! let fields = Bn254BlackBoxSolver .poseidon2_permutation(&input, state_length) .map_err(|error| InterpreterError::BlackBoxError(error, location))?; @@ -76,25 +333,135 @@ fn poseidon2_permutation( Ok(Value::Array(array, typ)) } +/// `fn keccakf1600(input: [u64; 25]) -> [u64; 25] {}` fn keccakf1600( interner: &mut NodeInterner, arguments: Vec<(Value, Location)>, location: Location, ) -> IResult<Value> { let input = check_one_argument(arguments, location)?; - let input_location = input.1; - let (input, typ) = get_array(interner, input)?; + let (state, typ) = get_fixed_array_map(interner, input, get_u64)?; - let input = try_vecmap(input, |integer| get_u64((integer, input_location)))?; - - let mut state = [0u64; 25]; - for (it, input_value) in state.iter_mut().zip(input.iter()) { - *it = *input_value; - } let result_lanes = acvm::blackbox_solver::keccakf1600(state) .map_err(|error| InterpreterError::BlackBoxError(error, location))?; let array: Vector<Value> = result_lanes.into_iter().map(Value::U64).collect(); Ok(Value::Array(array, typ)) } + +/// `pub fn sha256_compression(input: [u32; 16], state: [u32; 8]) -> [u32; 8]` +fn sha256_compression( + interner: &mut NodeInterner, + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult<Value> { + let (input, state) = check_two_arguments(arguments, location)?; + + let (input, _) = get_fixed_array_map(interner, input, get_u32)?; + let (mut state, typ) = get_fixed_array_map(interner, state, get_u32)?; + + acvm::blackbox_solver::sha256_compression(&mut state, &input); + + let state = state.into_iter().map(Value::U32).collect(); + Ok(Value::Array(state, typ)) +} + +/// Decode a `BigInt` struct. +/// +/// Returns the ID of the value in the solver. +fn get_bigint_id((value, location): (Value, Location)) -> IResult<u32> { + let (fields, typ) = get_struct_fields("BigInt", (value, location))?; + let p = get_struct_field("pointer", &fields, &typ, location, get_u32)?; + let m = get_struct_field("modulus", &fields, &typ, location, get_u32)?; + assert_eq!(p, m, "`pointer` and `modulus` are expected to be the same"); + Ok(p) +} + +/// Decode an `EmbeddedCurvePoint` struct. +/// +/// Returns `(x, y, is_infinite)`. +fn get_embedded_curve_point( + (value, location): (Value, Location), +) -> IResult<(FieldElement, FieldElement, bool)> { + let (fields, typ) = get_struct_fields("EmbeddedCurvePoint", (value, location))?; + let x = get_struct_field("x", &fields, &typ, location, get_field)?; + let y = get_struct_field("y", &fields, &typ, location, get_field)?; + let is_infinite = get_struct_field("is_infinite", &fields, &typ, location, get_bool)?; + Ok((x, y, is_infinite)) +}
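For reference, the compression built-in mutates an eight-word state in place, exactly as the call above shows (`acvm::blackbox_solver::sha256_compression(&mut state, &input)`). A small usage sketch, assuming the `acvm` crate as a dependency; the initial state words are the standard SHA-256 IV from FIPS 180-4:

```rust
fn main() {
    // SHA-256 initial hash values (FIPS 180-4).
    let mut state: [u32; 8] = [
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
        0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
    ];
    // One all-zero 512-bit message block, as sixteen u32 words.
    let input: [u32; 16] = [0; 16];
    // The state array is updated in place; there is no return value.
    acvm::blackbox_solver::sha256_compression(&mut state, &input);
    println!("compressed state: {state:08x?}");
}
```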
+ +/// Decode an `EmbeddedCurveScalar` struct. +/// +/// Returns `(lo, hi)`. +fn get_embedded_curve_scalar( + (value, location): (Value, Location), +) -> IResult<(FieldElement, FieldElement)> { + let (fields, typ) = get_struct_fields("EmbeddedCurveScalar", (value, location))?; + let lo = get_struct_field("lo", &fields, &typ, location, get_field)?; + let hi = get_struct_field("hi", &fields, &typ, location, get_field)?; + Ok((lo, hi)) +} + +fn to_bigint(id: u32, typ: Type) -> Value { + to_struct([("pointer", Value::U32(id)), ("modulus", Value::U32(id))], typ) +} + +#[cfg(test)] +mod tests { + use acvm::acir::BlackBoxFunc; + use noirc_errors::Location; + use strum::IntoEnumIterator; + + use crate::hir::comptime::tests::with_interpreter; + use crate::hir::comptime::InterpreterError::{ + ArgumentCountMismatch, InvalidInComptimeContext, Unimplemented, + }; + use crate::Type; + + use super::call_foreign; + + /// Check that all `BlackBoxFunc` are covered by `call_foreign`. + #[test] + fn test_blackbox_implemented() { + let dummy = " + comptime fn main() -> pub u8 { + 0 + } + "; + + let not_implemented = with_interpreter(dummy, |interpreter, _, _| { + let no_location = Location::dummy(); + let mut not_implemented = Vec::new(); + + for blackbox in BlackBoxFunc::iter() { + let name = blackbox.name(); + match call_foreign( + interpreter.elaborator.interner, + &mut interpreter.bigint_solver, + name, + Vec::new(), + Type::Unit, + no_location, + ) { + Ok(_) => { + // Exists and works with no args (unlikely) + } + Err(ArgumentCountMismatch { .. }) => { + // Exists but doesn't work with no args (expected) + } + Err(InvalidInComptimeContext { .. }) => {} + Err(Unimplemented { .. }) => not_implemented.push(name), + Err(other) => panic!("unexpected error: {other:?}"), + }; + } + + not_implemented + }); + + assert!( + not_implemented.is_empty(), + "unimplemented blackbox functions: {not_implemented:?}" + ); + } +}
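The test module above uses a useful exhaustiveness pattern: iterate every enum variant via `strum::IntoEnumIterator` and probe the dispatcher, so adding a `BlackBoxFunc` variant without a comptime implementation fails CI instead of silently hitting `Unimplemented` at runtime. A standalone sketch of the pattern (`Op` and `dispatch` are hypothetical; the strum usage mirrors the real test and assumes `strum`/`strum_macros` as dependencies):

```rust
use strum::IntoEnumIterator;
use strum_macros::EnumIter;

// Deriving EnumIter gives `Op::iter()`, which yields every variant in order.
#[derive(Debug, EnumIter)]
enum Op {
    Add,
    Mul,
    Keccak,
}

fn dispatch(op: &Op) -> Result<(), String> {
    match op {
        Op::Add | Op::Mul => Ok(()),
        other => Err(format!("unimplemented: {other:?}")),
    }
}

#[test]
fn all_ops_implemented() {
    // Collect every variant the dispatcher rejects; the assertion message
    // names exactly which variants still need an implementation.
    let unimplemented: Vec<Op> = Op::iter().filter(|op| dispatch(op).is_err()).collect();
    assert!(unimplemented.is_empty(), "unimplemented ops: {unimplemented:?}");
}
```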
diff --git a/compiler/noirc_frontend/src/hir/comptime/tests.rs b/compiler/noirc_frontend/src/hir/comptime/tests.rs index e033ec6ddb9..2d3bf928917 100644 --- a/compiler/noirc_frontend/src/hir/comptime/tests.rs +++ b/compiler/noirc_frontend/src/hir/comptime/tests.rs @@ -9,14 +9,20 @@ use noirc_errors::Location; use super::errors::InterpreterError; use super::value::Value; +use super::Interpreter; use crate::elaborator::Elaborator; -use crate::hir::def_collector::dc_crate::DefCollector; +use crate::hir::def_collector::dc_crate::{CompilationError, DefCollector}; use crate::hir::def_collector::dc_mod::collect_defs; use crate::hir::def_map::{CrateDefMap, LocalModuleId, ModuleData}; use crate::hir::{Context, ParsedFiles}; +use crate::node_interner::FuncId; use crate::parse_program; -fn interpret_helper(src: &str) -> Result<Value, InterpreterError> { +/// Create an interpreter for a code snippet and pass it to a test function. +pub(crate) fn with_interpreter<T>( + src: &str, + f: impl FnOnce(&mut Interpreter, FuncId, &[(CompilationError, FileId)]) -> T, +) -> T { + let file = FileId::default(); // Can't use Index::test_new here for some reason, even with #[cfg(test)]. @@ -51,14 +57,24 @@ fn interpret_helper(src: &str) -> Result<Value, InterpreterError> { context.def_maps.insert(krate, collector.def_map); let main = context.get_main_function(&krate).expect("Expected 'main' function"); + let mut elaborator = Elaborator::elaborate_and_return_self(&mut context, krate, collector.items, None); - assert_eq!(elaborator.errors.len(), 0); + + let errors = elaborator.errors.clone(); let mut interpreter = elaborator.setup_interpreter(); - let no_location = Location::dummy(); - interpreter.call_function(main, Vec::new(), HashMap::new(), no_location) + f(&mut interpreter, main, &errors) +} + +/// Evaluate a code snippet by calling the `main` function. +fn interpret_helper(src: &str) -> Result<Value, InterpreterError> { + with_interpreter(src, |interpreter, main, errors| { + assert_eq!(errors.len(), 0); + let no_location = Location::dummy(); + interpreter.call_function(main, Vec::new(), HashMap::new(), no_location) + }) } fn interpret(src: &str) -> Value { diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs index bae57daae15..e7953aab5a4 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs @@ -216,6 +216,13 @@ impl<'a> ModCollector<'a> { errors.push((error.into(), self.file_id)); } + if noir_function.def.attributes.has_export() { + let error = DefCollectorErrorKind::ExportOnAssociatedFunction { + span: noir_function.name_ident().span(), + }; + errors.push((error.into(), self.file_id)); + } + let location = Location::new(noir_function.def.span, self.file_id); context.def_interner.push_function(*func_id, &noir_function.def, module, location); } @@ -944,6 +951,7 @@ pub fn collect_function( } else { function.name() == MAIN_FUNCTION }; + let has_export = function.def.attributes.has_export(); let name = function.name_ident().clone(); let func_id = interner.push_empty_fn(); @@ -954,7 +962,7 @@ pub fn collect_function( interner.register_function(func_id, &function.def); } - if !is_test && !is_entry_point_function { + if !is_test && !is_entry_point_function && !has_export { let item = UnusedItem::Function(func_id); usage_tracker.add_unused_item(module, name.clone(), item, visibility); } @@ -1087,6 +1095,12 @@ pub fn collect_impl( errors.push((error.into(), file_id)); continue; } + if method.def.attributes.has_export() { + let error = DefCollectorErrorKind::ExportOnAssociatedFunction { + span: method.name_ident().span(), + }; + errors.push((error.into(), file_id)); + } + let func_id = interner.push_empty_fn(); method.def.where_clause.extend(r#impl.where_clause.clone()); @@ -1257,6 +1271,7 @@ pub(crate) fn collect_global( // Add the statement to the scope so its path can be looked up later let result = def_map.modules[module_id.0].declare_global(name.clone(), visibility, global_id); + // Globals marked as ABI don't have to be used.
if !is_abi { let parent_module_id = ModuleId { krate: crate_id, local_id: module_id }; usage_tracker.add_unused_item( diff --git a/compiler/noirc_frontend/src/hir/def_collector/errors.rs b/compiler/noirc_frontend/src/hir/def_collector/errors.rs index c08b4ff2062..cafbc670e32 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/errors.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/errors.rs @@ -84,6 +84,8 @@ pub enum DefCollectorErrorKind { UnsupportedNumericGenericType(#[from] UnsupportedNumericGenericType), #[error("The `#[test]` attribute may only be used on a non-associated function")] TestOnAssociatedFunction { span: Span }, + #[error("The `#[export]` attribute may only be used on a non-associated function")] + ExportOnAssociatedFunction { span: Span }, } impl DefCollectorErrorKind { @@ -182,8 +184,8 @@ impl<'a> From<&'a DefCollectorErrorKind> for Diagnostic { DefCollectorErrorKind::PathResolutionError(error) => error.into(), DefCollectorErrorKind::CannotReexportItemWithLessVisibility{item_name, desired_visibility} => { Diagnostic::simple_warning( - format!("cannot re-export {item_name} because it has less visibility than this use statement"), - format!("consider marking {item_name} as {desired_visibility}"), + format!("cannot re-export {item_name} because it has less visibility than this use statement"), + format!("consider marking {item_name} as {desired_visibility}"), item_name.span()) } DefCollectorErrorKind::NonStructTypeInImpl { span } => Diagnostic::simple_error( @@ -298,7 +300,11 @@ impl<'a> From<&'a DefCollectorErrorKind> for Diagnostic { String::new(), *span, ), - + DefCollectorErrorKind::ExportOnAssociatedFunction { span } => Diagnostic::simple_error( + "The `#[export]` attribute is disallowed on `impl` methods".into(), + String::new(), + *span, + ), } } } diff --git a/compiler/noirc_frontend/src/hir/def_map/mod.rs b/compiler/noirc_frontend/src/hir/def_map/mod.rs index de94f73b44b..3bb16a92fdb 100644 --- a/compiler/noirc_frontend/src/hir/def_map/mod.rs +++ b/compiler/noirc_frontend/src/hir/def_map/mod.rs @@ -193,11 +193,7 @@ impl CrateDefMap { module.value_definitions().filter_map(|id| { if let Some(func_id) = id.as_function() { let attributes = interner.function_attributes(&func_id); - if attributes.secondary.contains(&SecondaryAttribute::Export) { - Some(func_id) - } else { - None - } + attributes.has_export().then_some(func_id) } else { None } diff --git a/compiler/noirc_frontend/src/hir/resolution/errors.rs b/compiler/noirc_frontend/src/hir/resolution/errors.rs index 80bd5247ee6..5c8e0a1b53e 100644 --- a/compiler/noirc_frontend/src/hir/resolution/errors.rs +++ b/compiler/noirc_frontend/src/hir/resolution/errors.rs @@ -223,11 +223,21 @@ impl<'a> From<&'a ResolverError> for Diagnostic { *span, ) } - ResolverError::VariableNotDeclared { name, span } => Diagnostic::simple_error( - format!("cannot find `{name}` in this scope "), - "not found in this scope".to_string(), - *span, - ), + ResolverError::VariableNotDeclared { name, span } => { + if name == "_" { + Diagnostic::simple_error( + "in expressions, `_` can only be used on the left-hand side of an assignment".to_string(), + "`_` not allowed here".to_string(), + *span, + ) + } else { + Diagnostic::simple_error( + format!("cannot find `{name}` in this scope"), + "not found in this scope".to_string(), + *span, + ) + } + }, ResolverError::PathIsNotIdent { span } => Diagnostic::simple_error( "cannot use path as an identifier".to_string(), String::new(), diff --git 
a/compiler/noirc_frontend/src/hir/type_check/errors.rs b/compiler/noirc_frontend/src/hir/type_check/errors.rs index f1d8178fa38..c75f6ecfb45 100644 --- a/compiler/noirc_frontend/src/hir/type_check/errors.rs +++ b/compiler/noirc_frontend/src/hir/type_check/errors.rs @@ -101,6 +101,8 @@ pub enum TypeCheckError { CannotMutateImmutableVariable { name: String, span: Span }, #[error("No method named '{method_name}' found for type '{object_type}'")] UnresolvedMethodCall { method_name: String, object_type: Type, span: Span }, + #[error("Cannot invoke function field '{method_name}' on type '{object_type}' as a method")] + CannotInvokeStructFieldFunctionType { method_name: String, object_type: Type, span: Span }, #[error("Integers must have the same signedness LHS is {sign_x:?}, RHS is {sign_y:?}")] IntegerSignedness { sign_x: Signedness, sign_y: Signedness, span: Span }, #[error("Integers must have the same bit width LHS is {bit_width_x}, RHS is {bit_width_y}")] @@ -504,6 +506,13 @@ impl<'a> From<&'a TypeCheckError> for Diagnostic { TypeCheckError::CyclicType { typ: _, span } => { Diagnostic::simple_error(error.to_string(), "Cyclic types have unlimited size and are prohibited in Noir".into(), *span) } + TypeCheckError::CannotInvokeStructFieldFunctionType { method_name, object_type, span } => { + Diagnostic::simple_error( + format!("Cannot invoke function field '{method_name}' on type '{object_type}' as a method"), + format!("to call the function stored in '{method_name}', surround the field access with parentheses: '(', ')'"), + *span, + ) + }, } } } diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 35829845277..d88a783db5b 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -1093,6 +1093,14 @@ impl Type { } } + pub fn is_function(&self) -> bool { + match self.follow_bindings_shallow().as_ref() { + Type::Function(..) => true, + Type::Alias(alias_type, _) => alias_type.borrow().typ.is_function(), + _ => false, + } + } + /// True if this type can be used as a parameter to `main` or a contract function. /// This is only false for unsized types like slices or slices that do not make sense /// as a program input such as named generics or mutable references. diff --git a/compiler/noirc_frontend/src/lexer/token.rs b/compiler/noirc_frontend/src/lexer/token.rs index dbb28cf78c0..836161c7c9f 100644 --- a/compiler/noirc_frontend/src/lexer/token.rs +++ b/compiler/noirc_frontend/src/lexer/token.rs @@ -676,9 +676,7 @@ impl Attributes { /// This is useful for finding out if we should compile a contract method /// as an entry point or not. pub fn has_contract_library_method(&self) -> bool { - self.secondary - .iter() - .any(|attribute| attribute == &SecondaryAttribute::ContractLibraryMethod) + self.has_secondary_attr(&SecondaryAttribute::ContractLibraryMethod) } pub fn is_test_function(&self) -> bool { @@ -718,11 +716,21 @@ impl Attributes { } pub fn has_varargs(&self) -> bool { - self.secondary.iter().any(|attr| matches!(attr, SecondaryAttribute::Varargs)) + self.has_secondary_attr(&SecondaryAttribute::Varargs) } pub fn has_use_callers_scope(&self) -> bool { - self.secondary.iter().any(|attr| matches!(attr, SecondaryAttribute::UseCallersScope)) + self.has_secondary_attr(&SecondaryAttribute::UseCallersScope) + } + + /// True if the function is marked with an `#[export]` attribute. 
+ pub fn has_export(&self) -> bool { + self.has_secondary_attr(&SecondaryAttribute::Export) + } + + /// Check if secondary attributes contain a specific instance. + pub fn has_secondary_attr(&self, attr: &SecondaryAttribute) -> bool { + self.secondary.contains(attr) } } diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 736d37fe83f..6d70ea2fd6d 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -2351,7 +2351,7 @@ impl Methods { } /// Select the 1 matching method with an object type matching `typ` - fn find_matching_method( + pub fn find_matching_method( &self, typ: &Type, has_self_param: bool, diff --git a/compiler/noirc_frontend/src/parser/errors.rs b/compiler/noirc_frontend/src/parser/errors.rs index bcb4ce1c616..899928528e6 100644 --- a/compiler/noirc_frontend/src/parser/errors.rs +++ b/compiler/noirc_frontend/src/parser/errors.rs @@ -19,6 +19,8 @@ pub enum ParserErrorReason { UnexpectedComma, #[error("Expected a `{token}` separating these two {items}")] ExpectedTokenSeparatingTwoItems { token: Token, items: &'static str }, + #[error("Expected `mut` after `&`, found `{found}`")] + ExpectedMutAfterAmpersand { found: Token }, #[error("Invalid left-hand side of assignment")] InvalidLeftHandSideOfAssignment, #[error("Expected trait, found {found}")] @@ -265,6 +267,11 @@ impl<'a> From<&'a ParserError> for Diagnostic { error.span, ), ParserErrorReason::Lexer(error) => error.into(), + ParserErrorReason::ExpectedMutAfterAmpersand { found } => Diagnostic::simple_error( + format!("Expected `mut` after `&`, found `{found}`"), + "Noir doesn't have immutable references, only mutable references".to_string(), + error.span, + ), other => Diagnostic::simple_error(format!("{other}"), String::new(), error.span), }, None => { diff --git a/compiler/noirc_frontend/src/parser/parser.rs b/compiler/noirc_frontend/src/parser/parser.rs index f369839ddd4..c2f7b781873 100644 --- a/compiler/noirc_frontend/src/parser/parser.rs +++ b/compiler/noirc_frontend/src/parser/parser.rs @@ -498,6 +498,13 @@ impl<'a> Parser<'a> { self.push_error(ParserErrorReason::ExpectedTokenSeparatingTwoItems { token, items }, span); } + fn expected_mut_after_ampersand(&mut self) { + self.push_error( + ParserErrorReason::ExpectedMutAfterAmpersand { found: self.token.token().clone() }, + self.current_token_span, + ); + } + fn modifiers_not_followed_by_an_item(&mut self, modifiers: Modifiers) { self.visibility_not_followed_by_an_item(modifiers); self.unconstrained_not_followed_by_an_item(modifiers); diff --git a/compiler/noirc_frontend/src/parser/parser/types.rs b/compiler/noirc_frontend/src/parser/parser/types.rs index be3d5287cab..0de94a89be5 100644 --- a/compiler/noirc_frontend/src/parser/parser/types.rs +++ b/compiler/noirc_frontend/src/parser/parser/types.rs @@ -341,7 +341,10 @@ impl<'a> Parser<'a> { fn parses_mutable_reference_type(&mut self) -> Option<UnresolvedTypeData> { if self.eat(Token::Ampersand) { - self.eat_keyword_or_error(Keyword::Mut); + if !self.eat_keyword(Keyword::Mut) { + self.expected_mut_after_ampersand(); + } + return Some(UnresolvedTypeData::MutableReference(Box::new( self.parse_type_or_error(), )));
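The parser change above is a classic report-and-continue recovery: when `&` is not followed by `mut`, it pushes `ExpectedMutAfterAmpersand` as a diagnostic but still parses the mutable reference type, so later errors stay meaningful instead of cascading. A toy sketch of the same recovery style (all types here are stand-ins, not the real parser):

```rust
#[derive(Debug, Clone, PartialEq)]
enum Token {
    Ampersand,
    Mut,
    Ident(String),
}

struct Parser {
    tokens: Vec<Token>,
    pos: usize,
    errors: Vec<String>,
}

impl Parser {
    fn eat(&mut self, token: &Token) -> bool {
        if self.tokens.get(self.pos) == Some(token) {
            self.pos += 1;
            true
        } else {
            false
        }
    }

    fn parse_reference_type(&mut self) -> Option<String> {
        if !self.eat(&Token::Ampersand) {
            return None;
        }
        if !self.eat(&Token::Mut) {
            // Recover: record the error, then continue as if `mut` were present.
            self.errors
                .push(format!("expected `mut` after `&`, found {:?}", self.tokens.get(self.pos)));
        }
        match self.tokens.get(self.pos) {
            Some(Token::Ident(name)) => {
                let parsed = format!("&mut {name}");
                self.pos += 1;
                Some(parsed)
            }
            _ => None,
        }
    }
}

fn main() {
    let mut parser = Parser {
        tokens: vec![Token::Ampersand, Token::Ident("Field".to_string())],
        pos: 0,
        errors: Vec::new(),
    };
    // Parsing still succeeds, and the missing `mut` is reported as a diagnostic.
    assert_eq!(parser.parse_reference_type(), Some("&mut Field".to_string()));
    assert_eq!(parser.errors.len(), 1);
}
```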
diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index d789fbed950..e4796fcab02 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -3740,49 +3740,154 @@ fn allows_struct_with_generic_infix_type_as_main_input_3() { } #[test] -fn disallows_test_attribute_on_impl_method() { +fn errors_with_better_message_when_trying_to_invoke_struct_field_that_is_a_function() { let src = r#" - pub struct Foo {} - impl Foo { - #[test] - fn foo() {} - } + pub struct Foo { + wrapped: fn(Field) -> bool, + } - fn main() {} + impl Foo { + fn call(self) -> bool { + self.wrapped(1) + } + } + + fn main() {} "#; let errors = get_program_errors(src); assert_eq!(errors.len(), 1); - assert!(matches!( - errors[0].0, - CompilationError::DefinitionError(DefCollectorErrorKind::TestOnAssociatedFunction { - span: _ - }) - )); + let CompilationError::TypeError(TypeCheckError::CannotInvokeStructFieldFunctionType { + method_name, + .. + }) = &errors[0].0 + else { + panic!("Expected a 'CannotInvokeStructFieldFunctionType' error, got {:?}", errors[0].0); + }; + + assert_eq!(method_name, "wrapped"); +} + +fn test_disallows_attribute_on_impl_method( + attr: &str, + check_error: impl FnOnce(&CompilationError), +) { + let src = format!( + " + pub struct Foo {{ }} + + impl Foo {{ + #[{attr}] + fn foo() {{ }} + }} + + fn main() {{ }} + " + ); + let errors = get_program_errors(&src); + assert_eq!(errors.len(), 1); + check_error(&errors[0].0); +} + +fn test_disallows_attribute_on_trait_impl_method( + attr: &str, + check_error: impl FnOnce(&CompilationError), +) { + let src = format!( + " + pub trait Trait {{ + fn foo() {{ }} + }} + + pub struct Foo {{ }} + + impl Trait for Foo {{ + #[{attr}] + fn foo() {{ }} + }} + + fn main() {{ }} + " + ); + let errors = get_program_errors(&src); + assert_eq!(errors.len(), 1); + check_error(&errors[0].0); +} + +#[test] +fn disallows_test_attribute_on_impl_method() { + test_disallows_attribute_on_impl_method("test", |error| { + assert!(matches!( + error, + CompilationError::DefinitionError( + DefCollectorErrorKind::TestOnAssociatedFunction { .. } + ) + )); + }); } #[test] fn disallows_test_attribute_on_trait_impl_method() { + test_disallows_attribute_on_trait_impl_method("test", |error| { + assert!(matches!( + error, + CompilationError::DefinitionError( + DefCollectorErrorKind::TestOnAssociatedFunction { .. } + ) + )); + }); +} + +#[test] +fn disallows_export_attribute_on_impl_method() { + test_disallows_attribute_on_impl_method("export", |error| { + assert!(matches!( + error, + CompilationError::DefinitionError( + DefCollectorErrorKind::ExportOnAssociatedFunction { .. } + ) + )); + }); +} + +#[test] +fn disallows_export_attribute_on_trait_impl_method() { + test_disallows_attribute_on_trait_impl_method("export", |error| { + assert!(matches!( + error, + CompilationError::DefinitionError( + DefCollectorErrorKind::ExportOnAssociatedFunction { .. } + ) + )); + }); +} + +#[test] +fn allows_multiple_underscore_parameters() { let src = r#" - pub trait Trait { - fn foo() {} - } + pub fn foo(_: i32, _: i64) {} - pub struct Foo {} - impl Trait for Foo { - #[test] - fn foo() {} - } + fn main() {} + "#; + assert_no_errors(src); +} - fn main() {} +#[test] +fn disallows_underscore_on_right_hand_side() { + let src = r#" + fn main() { + let _ = 1; + let _x = _; + } "#; let errors = get_program_errors(src); assert_eq!(errors.len(), 1); - assert!(matches!( - errors[0].0, - CompilationError::DefinitionError(DefCollectorErrorKind::TestOnAssociatedFunction { - span: _ - }) - )); + let CompilationError::ResolverError(ResolverError::VariableNotDeclared { name, ..
}) = + &errors[0].0 + else { + panic!("Expected a VariableNotDeclared error, got {:?}", errors[0].0); + }; + + assert_eq!(name, "_"); } diff --git a/compiler/noirc_frontend/src/tests/metaprogramming.rs b/compiler/noirc_frontend/src/tests/metaprogramming.rs index 82c40203244..89a049ebc9d 100644 --- a/compiler/noirc_frontend/src/tests/metaprogramming.rs +++ b/compiler/noirc_frontend/src/tests/metaprogramming.rs @@ -141,3 +141,23 @@ fn errors_if_macros_inject_functions_with_name_collisions() { ) if contents == "foo" )); } + +#[test] +fn uses_correct_type_for_attribute_arguments() { + let src = r#" + #[foo(32)] + comptime fn foo(_f: FunctionDefinition, i: u32) { + let y: u32 = 1; + let _ = y == i; + } + + #[bar([0; 2])] + comptime fn bar(_f: FunctionDefinition, i: [u32; 2]) { + let y: u32 = 1; + let _ = y == i[0]; + } + + fn main() {} + "#; + assert_no_errors(src); +} diff --git a/compiler/noirc_frontend/src/tests/unused_items.rs b/compiler/noirc_frontend/src/tests/unused_items.rs index 9304209e501..c38e604f2c3 100644 --- a/compiler/noirc_frontend/src/tests/unused_items.rs +++ b/compiler/noirc_frontend/src/tests/unused_items.rs @@ -224,9 +224,31 @@ fn does_not_warn_on_unused_global_if_it_has_an_abi_attribute() { assert_no_errors(src); } +#[test] +fn does_not_warn_on_unused_struct_if_it_has_an_abi_attribute() { + let src = r#" + #[abi(dummy)] + struct Foo { bar: u8 } + + fn main() {} + "#; + assert_no_errors(src); +} + +#[test] +fn does_not_warn_on_unused_function_if_it_has_an_export_attribute() { + let src = r#" + #[export] + fn foo() {} + + fn main() {} + "#; + assert_no_errors(src); +} + #[test] fn no_warning_on_inner_struct_when_parent_is_used() { - let src = r#" + let src = r#" struct Bar { inner: [Field; 3], } @@ -247,7 +269,7 @@ fn no_warning_on_inner_struct_when_parent_is_used() { #[test] fn no_warning_on_struct_if_it_has_an_abi_attribute() { - let src = r#" + let src = r#" #[abi(functions)] struct Foo { a: Field, @@ -260,7 +282,7 @@ fn no_warning_on_struct_if_it_has_an_abi_attribute() { #[test] fn no_warning_on_indirect_struct_if_it_has_an_abi_attribute() { - let src = r#" + let src = r#" struct Bar { field: Field, } @@ -277,7 +299,7 @@ fn no_warning_on_indirect_struct_if_it_has_an_abi_attribute() { #[test] fn no_warning_on_self_in_trait_impl() { - let src = r#" + let src = r#" struct Bar {} trait Foo { @@ -298,18 +320,18 @@ fn no_warning_on_self_in_trait_impl() { #[test] fn resolves_trait_where_clause_in_the_correct_module() { // This is a regression test for https://github.com/noir-lang/noir/issues/6479 - let src = r#" + let src = r#" mod foo { pub trait Foo {} } - + use foo::Foo; - + pub trait Bar where T: Foo, {} - + fn main() {} "#; assert_no_errors(src); diff --git a/compiler/noirc_frontend/src/usage_tracker.rs b/compiler/noirc_frontend/src/usage_tracker.rs index fa87ca6961b..6987358ddb7 100644 --- a/compiler/noirc_frontend/src/usage_tracker.rs +++ b/compiler/noirc_frontend/src/usage_tracker.rs @@ -35,6 +35,8 @@ pub struct UsageTracker { } impl UsageTracker { + /// Register an item as unused, waiting to be marked as used later. + /// Things that should not emit warnings should not be added at all. pub(crate) fn add_unused_item( &mut self, module_id: ModuleId, @@ -73,6 +75,7 @@ impl UsageTracker { }; } + /// Get all the unused items per module. 
pub fn unused_items(&self) -> &HashMap<ModuleId, HashMap<Ident, UnusedItem>> { &self.unused_items }
diff --git a/compiler/noirc_printable_type/src/lib.rs b/compiler/noirc_printable_type/src/lib.rs index 5ab04c6f576..838a2472125 100644 --- a/compiler/noirc_printable_type/src/lib.rs +++ b/compiler/noirc_printable_type/src/lib.rs @@ -69,6 +69,9 @@ pub enum PrintableValueDisplay { #[derive(Debug, Error)] pub enum ForeignCallError { + #[error("No handler could be found for foreign call `{0}`")] + NoHandler(String), + #[error("Foreign call inputs needed for execution are missing")] MissingForeignCallInputs,
diff --git a/compiler/wasm/Cargo.toml b/compiler/wasm/Cargo.toml index c8b8c3bb06e..9951b23f609 100644 --- a/compiler/wasm/Cargo.toml +++ b/compiler/wasm/Cargo.toml @@ -1,10 +1,12 @@ [package] name = "noir_wasm" +description = "A JS interface to the Noir compiler" version.workspace = true authors.workspace = true edition.workspace = true rust-version.workspace = true license.workspace = true +repository.workspace = true [lints] workspace = true @@ -42,4 +44,4 @@ getrandom = { workspace = true, features = ["js"] } rust-embed = { workspace = true, features = ["debug-embed"] } [build-dependencies] -build-data.workspace = true \ No newline at end of file +build-data.workspace = true
diff --git a/compiler/wasm/LICENSE-APACHE b/compiler/wasm/LICENSE-APACHE new file mode 100644 index 00000000000..261eeb9e9f8 --- /dev/null +++ b/compiler/wasm/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship.
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/compiler/wasm/LICENSE-MIT b/compiler/wasm/LICENSE-MIT new file mode 100644 index 00000000000..a93d7f55c8e --- /dev/null +++ b/compiler/wasm/LICENSE-MIT @@ -0,0 +1,21 @@ +MIT License + + Copyright (c) 2021-2023 noir-lang + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. 
+ + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE.
diff --git a/compiler/wasm/tsconfig.json b/compiler/wasm/tsconfig.json index d2ae58b8fc9..42c7396aa83 100644 --- a/compiler/wasm/tsconfig.json +++ b/compiler/wasm/tsconfig.json @@ -18,4 +18,4 @@ "allowJs": true, }, "exclude": ["node_modules"] -} \ No newline at end of file +}
diff --git a/cspell.json b/cspell.json index 36bba737cd7..15bba2cb5f8 100644 --- a/cspell.json +++ b/cspell.json @@ -106,6 +106,7 @@ "Guillaume", "gzipped", "hasher", + "heaptrack", "hexdigit", "higher-kinded", "Hindley-Milner",
diff --git a/docs/docs/how_to/debugger/debugging_with_the_repl.md b/docs/docs/how_to/debugger/debugging_with_the_repl.md index 09e5bae68ad..1d64dae3f37 100644 --- a/docs/docs/how_to/debugger/debugging_with_the_repl.md +++ b/docs/docs/how_to/debugger/debugging_with_the_repl.md @@ -1,7 +1,7 @@ --- title: Using the REPL Debugger description: - Step by step guide on how to debug your Noir circuits with the REPL Debugger. + Step-by-step guide on how to debug your Noir circuits with the REPL Debugger. keywords: [ Nargo,
diff --git a/docs/docs/how_to/debugger/debugging_with_vs_code.md b/docs/docs/how_to/debugger/debugging_with_vs_code.md index a5858c1a5eb..8bda93324f5 100644 --- a/docs/docs/how_to/debugger/debugging_with_vs_code.md +++ b/docs/docs/how_to/debugger/debugging_with_vs_code.md @@ -1,7 +1,7 @@ --- title: Using the VS Code Debugger description: - Step by step guide on how to debug your Noir circuits with the VS Code Debugger configuration and features. + Step-by-step guide on how to debug your Noir circuits with the VS Code Debugger configuration and features. keywords: [ Nargo, @@ -65,4 +65,4 @@ We just need to click the to the right of the line number 18. Once the breakpoin Now we are debugging the `keccak256` function, notice the _Call Stack pane_ at the lower right. This lets us inspect the current call stack of our process. -That covers most of the current debugger functionalities. Check out [the reference](../../reference/debugger/debugger_vscode.md) for more details on how to configure the debugger. \ No newline at end of file +That covers most of the current debugger functionalities. Check out [the reference](../../reference/debugger/debugger_vscode.md) for more details on how to configure the debugger.
diff --git a/docs/docs/noir/concepts/data_types/integers.md b/docs/docs/noir/concepts/data_types/integers.md index a1d59bf3166..f3badde62be 100644 --- a/docs/docs/noir/concepts/data_types/integers.md +++ b/docs/docs/noir/concepts/data_types/integers.md @@ -47,6 +47,17 @@ fn main() { The bit size determines the maximum and minimum range of value the integer type can store. For example, an `i8` variable can store a value in the range of -128 to 127 (i.e. $\\-2^{7}\\$ to $\\2^{7}-1\\$). + +```rust +fn main(x: i16, y: i16) { + // modulo + let c = x % y; + let c = x % -13; +} +``` + +The modulo operation is defined for negative integers as the remainder of integer division, so the identity `x = (x/y)*y + (x%y)` always holds.
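A quick worked check of that identity for a negative operand (a minimal sketch; whichever way the division rounds, the identity holds by construction):

```rust
fn main() {
    // With division truncating toward zero, -7 / 2 == -3 and -7 % 2 == -1;
    // the identity x == (x/y)*y + x%y holds under either rounding convention.
    let x: i16 = -7;
    let y: i16 = 2;
    assert(x == (x / y) * y + (x % y));
}
```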
+ ## 128 bits Unsigned Integers The built-in structure `U128` allows you to use 128-bit unsigned integers almost like a native integer type. However, there are some differences to keep in mind:
diff --git a/docs/docs/noir/modules_packages_crates/dependencies.md b/docs/docs/noir/modules_packages_crates/dependencies.md index 24e02de08fe..22186b22598 100644 --- a/docs/docs/noir/modules_packages_crates/dependencies.md +++ b/docs/docs/noir/modules_packages_crates/dependencies.md @@ -81,12 +81,10 @@ use std::hash::sha256; use std::scalar_mul::fixed_base_embedded_curve; ``` -Lastly, as demonstrated in the -[elliptic curve example](../standard_library/cryptographic_primitives/ec_primitives.md#examples), you -can import multiple items in the same line by enclosing them in curly braces: +Lastly, you can import multiple items on the same line by enclosing them in curly braces: ```rust -use std::ec::tecurve::affine::{Curve, Point}; +use std::hash::{keccak256, sha256}; ``` We don't have a way to consume libraries from inside a [workspace](./workspaces.md) as external dependencies right now.
diff --git a/docs/docs/noir/standard_library/cryptographic_primitives/ec_primitives.md b/docs/docs/noir/standard_library/cryptographic_primitives/ec_primitives.md deleted file mode 100644 index 00b8071487e..00000000000 --- a/docs/docs/noir/standard_library/cryptographic_primitives/ec_primitives.md +++ /dev/null @@ -1,101 +0,0 @@ ---- -title: Elliptic Curve Primitives -keywords: [cryptographic primitives, Noir project] -sidebar_position: 4 ---- - -Data structures and methods on them that allow you to carry out computations involving elliptic -curves over the (mathematical) field corresponding to `Field`. For the field currently at our -disposal, applications would involve a curve embedded in BN254, e.g. the -[Baby Jubjub curve](https://eips.ethereum.org/EIPS/eip-2494). - -## Data structures - -### Elliptic curve configurations - -(`std::ec::{tecurve,montcurve,swcurve}::{affine,curvegroup}::Curve`), i.e. the specific elliptic -curve you want to use, which would be specified using any one of the methods -`std::ec::{tecurve,montcurve,swcurve}::{affine,curvegroup}::new` which take the coefficients in the -defining equation together with a generator point as parameters. You can find more detail in the -comments in -[`noir_stdlib/src/ec/mod.nr`](https://github.com/noir-lang/noir/blob/master/noir_stdlib/src/ec/mod.nr), but -the gist of it is that the elliptic curves of interest are usually expressed in one of the standard -forms implemented here (Twisted Edwards, Montgomery and Short Weierstraß), and in addition to that, -you could choose to use `affine` coordinates (Cartesian coordinates - the usual (x,y) - possibly -together with a point at infinity) or `curvegroup` coordinates (some form of projective coordinates -requiring more coordinates but allowing for more efficient implementations of elliptic curve -operations). Conversions between all of these forms are provided, and under the hood these -conversions are done whenever an operation is more efficient in a different representation (or a -mixed coordinate representation is employed). - -### Points - -(`std::ec::{tecurve,montcurve,swcurve}::{affine,curvegroup}::Point`), i.e. points lying on the -elliptic curve. For a curve configuration `c` and a point `p`, it may be checked that `p` -does indeed lie on `c` by calling `c.contains(p1)`. - -## Methods - -(given a choice of curve representation, e.g.
use `std::ec::tecurve::affine::Curve` and use -`std::ec::tecurve::affine::Point`) - -- The **zero element** is given by `Point::zero()`, and we can verify whether a point `p: Point` is - zero by calling `p.is_zero()`. -- **Equality**: Points `p1: Point` and `p2: Point` may be checked for equality by calling - `p1.eq(p2)`. -- **Addition**: For `c: Curve` and points `p1: Point` and `p2: Point` on the curve, adding these two - points is accomplished by calling `c.add(p1,p2)`. -- **Negation**: For a point `p: Point`, `p.negate()` is its negation. -- **Subtraction**: For `c` and `p1`, `p2` as above, subtracting `p2` from `p1` is accomplished by - calling `c.subtract(p1,p2)`. -- **Scalar multiplication**: For `c` as above, `p: Point` a point on the curve and `n: Field`, - scalar multiplication is given by `c.mul(n,p)`. If instead `n :: [u1; N]`, i.e. `n` is a bit - array, the `bit_mul` method may be used instead: `c.bit_mul(n,p)` -- **Multi-scalar multiplication**: For `c` as above and arrays `n: [Field; N]` and `p: [Point; N]`, - multi-scalar multiplication is given by `c.msm(n,p)`. -- **Coordinate representation conversions**: The `into_group` method converts a point or curve - configuration in the affine representation to one in the CurveGroup representation, and - `into_affine` goes in the other direction. -- **Curve representation conversions**: `tecurve` and `montcurve` curves and points are equivalent - and may be converted between one another by calling `into_montcurve` or `into_tecurve` on their - configurations or points. `swcurve` is more general and a curve c of one of the other two types - may be converted to this representation by calling `c.into_swcurve()`, whereas a point `p` lying - on the curve given by `c` may be mapped to its corresponding `swcurve` point by calling - `c.map_into_swcurve(p)`. -- **Map-to-curve methods**: The Elligator 2 method of mapping a field element `n: Field` into a - `tecurve` or `montcurve` with configuration `c` may be called as `c.elligator2_map(n)`. For all of - the curve configurations, the SWU map-to-curve method may be called as `c.swu_map(z,n)`, where - `z: Field` depends on `Field` and `c` and must be chosen by the user (the conditions it needs to - satisfy are specified in the comments - [here](https://github.com/noir-lang/noir/blob/master/noir_stdlib/src/ec/mod.nr)). - -## Examples - -The -[ec_baby_jubjub test](https://github.com/noir-lang/noir/blob/master/test_programs/compile_success_empty/ec_baby_jubjub/src/main.nr) -illustrates all of the above primitives on various forms of the Baby Jubjub curve. A couple of more -interesting examples in Noir would be: - -Public-key cryptography: Given an elliptic curve and a 'base point' on it, determine the public key -from the private key. This is a matter of using scalar multiplication. In the case of Baby Jubjub, -for example, this code would do: - -```rust -use std::ec::tecurve::affine::{Curve, Point}; - -fn bjj_pub_key(priv_key: Field) -> Point -{ - - let bjj = Curve::new(168700, 168696, G::new(995203441582195749578291179787384436505546430278305826713579947235728471134,5472060717959818805561601436314318772137091100104008585924551046643952123905)); - - let base_pt = Point::new(5299619240641551281634865583518297030282874472190772894086521144482721001553, 16950150798460657717958625567821834550301663161624707787222815936182638968203); - - bjj.mul(priv_key,base_pt) -} -``` - -This would come in handy in a Merkle proof. 
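The snippet above constructs its generator with `G::new`, which under this file's imports presumably means `Point::new`. A self-contained sketch of the same derivation, with that assumption made explicit and a sanity check via the `contains` and `is_zero` methods listed earlier:

```rust
use std::ec::tecurve::affine::{Curve, Point};

// Sketch only: assumes `G::new` above was intended as `Point::new`.
fn bjj_pub_key_checked(priv_key: Field) -> Point {
    let bjj = Curve::new(
        168700,
        168696,
        // G
        Point::new(
            995203441582195749578291179787384436505546430278305826713579947235728471134,
            5472060717959818805561601436314318772137091100104008585924551046643952123905,
        ),
    );
    let base_pt = Point::new(
        5299619240641551281634865583518297030282874472190772894086521144482721001553,
        16950150798460657717958625567821834550301663161624707787222815936182638968203,
    );
    let pub_key = bjj.mul(priv_key, base_pt);
    // A well-formed public key is a non-zero point on the curve.
    assert(bjj.contains(pub_key));
    assert(!pub_key.is_zero());
    pub_key
}
```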
- -- EdDSA signature verification: This is a matter of combining these primitives with a suitable hash - function. See the [eddsa](https://github.com/noir-lang/eddsa) library an example of eddsa signature verification - over the Baby Jubjub curve. diff --git a/docs/docs/noir/standard_library/mem.md b/docs/docs/noir/standard_library/mem.md index 3619550273e..1e9102b32dc 100644 --- a/docs/docs/noir/standard_library/mem.md +++ b/docs/docs/noir/standard_library/mem.md @@ -42,7 +42,7 @@ fn checked_transmute(value: T) -> U Transmutes a value of one type into the same value but with a new type `U`. This function is safe to use since both types are asserted to be equal later during compilation after the concrete values for generic types become known. -This function is useful for cases where the compiler may fails a type check that is expected to pass where +This function is useful for cases where the compiler may fail a type check that is expected to pass where a user knows the two types to be equal. For example, when using arithmetic generics there are cases the compiler does not see as equal, such as `[Field; N*(A + B)]` and `[Field; N*A + N*B]`, which users may know to be equal. In these cases, `checked_transmute` can be used to cast the value to the desired type while also preserving safety diff --git a/docs/docs/noir/standard_library/meta/index.md b/docs/docs/noir/standard_library/meta/index.md index db0e5d0e411..76daa594b1f 100644 --- a/docs/docs/noir/standard_library/meta/index.md +++ b/docs/docs/noir/standard_library/meta/index.md @@ -128,7 +128,7 @@ way to write your derive handler. The arguments are as follows: - `for_each_field`: An operation to be performed on each field. E.g. `|name| quote { (self.$name == other.$name) }`. - `join_fields_with`: A separator to join each result of `for_each_field` with. E.g. `quote { & }`. You can also use an empty `quote {}` for no separator. -- `body`: The result of the field operations are passed into this function for any final processing. +- `body`: The result of the field operations is passed into this function for any final processing. This is the place to insert any setup/teardown code the trait requires. If the trait doesn't require any such code, you can return the body as-is: `|body| body`. diff --git a/docs/docs/noir/standard_library/meta/typ.md b/docs/docs/noir/standard_library/meta/typ.md index 71a36e629c6..455853bfea3 100644 --- a/docs/docs/noir/standard_library/meta/typ.md +++ b/docs/docs/noir/standard_library/meta/typ.md @@ -101,7 +101,7 @@ If this is a tuple type, returns each element type of the tuple. Retrieves the trait implementation that implements the given trait constraint for this type. If the trait constraint is not found, `None` is returned. Note that since the concrete trait implementation -for a trait constraint specified from a `where` clause is unknown, +for a trait constraint specified in a `where` clause is unknown, this function will return `None` in these cases. If you only want to know whether a type implements a trait, use `implements` instead. 
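When only a yes/no answer is needed, that `implements` fallback is straightforward to use in a comptime check. A minimal sketch, assuming the `as_trait_constraint` helper on quoted trait names from the metaprogramming API:

```rust
// Sketch: build a trait constraint from a quoted trait name and test it with
// `implements` (assumes `Quoted::as_trait_constraint` is available).
comptime fn assert_implements_default(typ: Type) {
    let constraint = quote { Default }.as_trait_constraint();
    assert(typ.implements(constraint));
}
```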
diff --git a/docs/versioned_docs/version-v0.32.0/how_to/debugger/debugging_with_the_repl.md b/docs/versioned_docs/version-v0.32.0/how_to/debugger/debugging_with_the_repl.md index 09e5bae68ad..1d64dae3f37 100644 --- a/docs/versioned_docs/version-v0.32.0/how_to/debugger/debugging_with_the_repl.md +++ b/docs/versioned_docs/version-v0.32.0/how_to/debugger/debugging_with_the_repl.md @@ -1,7 +1,7 @@ --- title: Using the REPL Debugger description: - Step by step guide on how to debug your Noir circuits with the REPL Debugger. + Step-by-step guide on how to debug your Noir circuits with the REPL Debugger. keywords: [ Nargo, diff --git a/docs/versioned_docs/version-v0.33.0/how_to/debugger/debugging_with_the_repl.md b/docs/versioned_docs/version-v0.33.0/how_to/debugger/debugging_with_the_repl.md index 09e5bae68ad..1d64dae3f37 100644 --- a/docs/versioned_docs/version-v0.33.0/how_to/debugger/debugging_with_the_repl.md +++ b/docs/versioned_docs/version-v0.33.0/how_to/debugger/debugging_with_the_repl.md @@ -1,7 +1,7 @@ --- title: Using the REPL Debugger description: - Step by step guide on how to debug your Noir circuits with the REPL Debugger. + Step-by-step guide on how to debug your Noir circuits with the REPL Debugger. keywords: [ Nargo, diff --git a/docs/versioned_docs/version-v0.34.0/how_to/debugger/debugging_with_the_repl.md b/docs/versioned_docs/version-v0.34.0/how_to/debugger/debugging_with_the_repl.md index 09e5bae68ad..1d64dae3f37 100644 --- a/docs/versioned_docs/version-v0.34.0/how_to/debugger/debugging_with_the_repl.md +++ b/docs/versioned_docs/version-v0.34.0/how_to/debugger/debugging_with_the_repl.md @@ -1,7 +1,7 @@ --- title: Using the REPL Debugger description: - Step by step guide on how to debug your Noir circuits with the REPL Debugger. + Step-by-step guide on how to debug your Noir circuits with the REPL Debugger. keywords: [ Nargo, diff --git a/docs/versioned_docs/version-v0.35.0/how_to/debugger/debugging_with_the_repl.md b/docs/versioned_docs/version-v0.35.0/how_to/debugger/debugging_with_the_repl.md index 09e5bae68ad..1d64dae3f37 100644 --- a/docs/versioned_docs/version-v0.35.0/how_to/debugger/debugging_with_the_repl.md +++ b/docs/versioned_docs/version-v0.35.0/how_to/debugger/debugging_with_the_repl.md @@ -1,7 +1,7 @@ --- title: Using the REPL Debugger description: - Step by step guide on how to debug your Noir circuits with the REPL Debugger. + Step-by-step guide on how to debug your Noir circuits with the REPL Debugger. keywords: [ Nargo, diff --git a/docs/versioned_docs/version-v0.36.0/how_to/debugger/debugging_with_the_repl.md b/docs/versioned_docs/version-v0.36.0/how_to/debugger/debugging_with_the_repl.md index 09e5bae68ad..1d64dae3f37 100644 --- a/docs/versioned_docs/version-v0.36.0/how_to/debugger/debugging_with_the_repl.md +++ b/docs/versioned_docs/version-v0.36.0/how_to/debugger/debugging_with_the_repl.md @@ -1,7 +1,7 @@ --- title: Using the REPL Debugger description: - Step by step guide on how to debug your Noir circuits with the REPL Debugger. + Step-by-step guide on how to debug your Noir circuits with the REPL Debugger. 
keywords: [ Nargo, diff --git a/docs/versioned_docs/version-v0.37.0/how_to/debugger/debugging_with_the_repl.md b/docs/versioned_docs/version-v0.37.0/how_to/debugger/debugging_with_the_repl.md index 09e5bae68ad..1d64dae3f37 100644 --- a/docs/versioned_docs/version-v0.37.0/how_to/debugger/debugging_with_the_repl.md +++ b/docs/versioned_docs/version-v0.37.0/how_to/debugger/debugging_with_the_repl.md @@ -1,7 +1,7 @@ --- title: Using the REPL Debugger description: - Step by step guide on how to debug your Noir circuits with the REPL Debugger. + Step-by-step guide on how to debug your Noir circuits with the REPL Debugger. keywords: [ Nargo, diff --git a/docs/versioned_docs/version-v0.38.0/how_to/debugger/debugging_with_the_repl.md b/docs/versioned_docs/version-v0.38.0/how_to/debugger/debugging_with_the_repl.md index 09e5bae68ad..1d64dae3f37 100644 --- a/docs/versioned_docs/version-v0.38.0/how_to/debugger/debugging_with_the_repl.md +++ b/docs/versioned_docs/version-v0.38.0/how_to/debugger/debugging_with_the_repl.md @@ -1,7 +1,7 @@ --- title: Using the REPL Debugger description: - Step by step guide on how to debug your Noir circuits with the REPL Debugger. + Step-by-step guide on how to debug your Noir circuits with the REPL Debugger. keywords: [ Nargo, diff --git a/docs/versioned_docs/version-v0.39.0/how_to/debugger/debugging_with_the_repl.md b/docs/versioned_docs/version-v0.39.0/how_to/debugger/debugging_with_the_repl.md index 09e5bae68ad..1d64dae3f37 100644 --- a/docs/versioned_docs/version-v0.39.0/how_to/debugger/debugging_with_the_repl.md +++ b/docs/versioned_docs/version-v0.39.0/how_to/debugger/debugging_with_the_repl.md @@ -1,7 +1,7 @@ --- title: Using the REPL Debugger description: - Step by step guide on how to debug your Noir circuits with the REPL Debugger. + Step-by-step guide on how to debug your Noir circuits with the REPL Debugger. keywords: [ Nargo, diff --git a/docs/versioned_docs/version-v1.0.0-beta.0/how_to/debugger/debugging_with_the_repl.md b/docs/versioned_docs/version-v1.0.0-beta.0/how_to/debugger/debugging_with_the_repl.md index 09e5bae68ad..1d64dae3f37 100644 --- a/docs/versioned_docs/version-v1.0.0-beta.0/how_to/debugger/debugging_with_the_repl.md +++ b/docs/versioned_docs/version-v1.0.0-beta.0/how_to/debugger/debugging_with_the_repl.md @@ -1,7 +1,7 @@ --- title: Using the REPL Debugger description: - Step by step guide on how to debug your Noir circuits with the REPL Debugger. + Step-by-step guide on how to debug your Noir circuits with the REPL Debugger. 
keywords: [ Nargo, diff --git a/noir_stdlib/src/ec/consts/mod.nr b/noir_stdlib/src/ec/consts/mod.nr deleted file mode 100644 index 73c594c6a26..00000000000 --- a/noir_stdlib/src/ec/consts/mod.nr +++ /dev/null @@ -1 +0,0 @@ -pub mod te; diff --git a/noir_stdlib/src/ec/consts/te.nr b/noir_stdlib/src/ec/consts/te.nr deleted file mode 100644 index 150eb849947..00000000000 --- a/noir_stdlib/src/ec/consts/te.nr +++ /dev/null @@ -1,33 +0,0 @@ -use crate::ec::tecurve::affine::Curve as TECurve; -use crate::ec::tecurve::affine::Point as TEPoint; - -pub struct BabyJubjub { - pub curve: TECurve, - pub base8: TEPoint, - pub suborder: Field, -} - -#[field(bn254)] -// Uncommenting this results in deprecated warnings in the stdlib -// #[deprecated] -pub fn baby_jubjub() -> BabyJubjub { - BabyJubjub { - // Baby Jubjub (ERC-2494) parameters in affine representation - curve: TECurve::new( - 168700, - 168696, - // G - TEPoint::new( - 995203441582195749578291179787384436505546430278305826713579947235728471134, - 5472060717959818805561601436314318772137091100104008585924551046643952123905, - ), - ), - // [8]G precalculated - base8: TEPoint::new( - 5299619240641551281634865583518297030282874472190772894086521144482721001553, - 16950150798460657717958625567821834550301663161624707787222815936182638968203, - ), - // The size of the group formed from multiplying the base field by 8. - suborder: 2736030358979909402780800718157159386076813972158567259200215660948447373041, - } -} diff --git a/noir_stdlib/src/ec/mod.nr b/noir_stdlib/src/ec/mod.nr deleted file mode 100644 index e2b2df1b794..00000000000 --- a/noir_stdlib/src/ec/mod.nr +++ /dev/null @@ -1,199 +0,0 @@ -// Elliptic curve implementation -// Overview -// ======== -// The following three elliptic curve representations are admissible: -pub mod tecurve; // Twisted Edwards curves -pub mod swcurve; // Elliptic curves in Short Weierstrass form -pub mod montcurve; // Montgomery curves -pub mod consts; // Commonly used curve presets -// -// Note that Twisted Edwards and Montgomery curves are (birationally) equivalent, so that -// they may be freely converted between one another, whereas Short Weierstrass curves are -// more general. Diagramatically: -// -// tecurve == montcurve `subset` swcurve -// -// Each module is further divided into two submodules, 'affine' and 'curvegroup', depending -// on the preferred coordinate representation. Affine coordinates are none other than the usual -// two-dimensional Cartesian coordinates used in the definitions of these curves, whereas -// 'CurveGroup' coordinates (terminology borrowed from Arkworks, whose conventions we try -// to follow) are special coordinate systems with respect to which the group operations may be -// implemented more efficiently, usually by means of an appropriate choice of projective coordinates. -// -// In each of these submodules, there is a Point struct and a Curve struct, the former -// representing a point in the coordinate system and the latter a curve configuration. -// -// Points -// ====== -// Points may be instantiated using the associated function `new`, which takes coordinates -// as its arguments. 
For instance, -// -// `let p = swcurve::Point::new(1,1);` -// -// The additive identity may be constructed by a call to the associated function `zero` of no -// arguments: -// -// `let zero = swcurve::Point::zero();` -// -// Points may be tested for equality by calling the method `eq`: -// -// `let pred = p.eq(zero);` -// -// There is also the method `is_zero` to explicitly check whether a point is the additive identity: -// -// `constrain pred == p.is_zero();` -// -// Points may be negated by calling the `negate` method and converted to CurveGroup (or affine) -// coordinates by calling the `into_group` (resp. `into_affine`) method on them. Finally, -// Points may be freely mapped between their respective Twisted Edwards and Montgomery -// representations by calling the `into_montcurve` or `into_tecurve` methods. For mappings -// between Twisted Edwards/Montgomery curves and Short Weierstrass curves, see the Curve section -// below, as the underlying mappings are those of curves rather than ambient spaces. -// As a rule, Points in affine (or CurveGroup) coordinates are mapped to Points in affine -// (resp. CurveGroup) coordinates. -// -// Curves -// ====== -// A curve configuration (Curve) is completely determined by the Field coefficients of its defining -// equation (a and b in the case of swcurve, a and d in the case of tecurve, and j and k in -// the case of montcurve) together with a generator (`gen`) in the corresponding coordinate system. -// For example, the Baby Jubjub curve configuration as defined in ERC-2494 may be instantiated as a Twisted -// Edwards curve in affine coordinates as follows: -// -// `let bjj_affine = tecurve::Curve::new(168700, 168696, tecurve::Point::new(995203441582195749578291179787384436505546430278305826713579947235728471134,5472060717959818805561601436314318772137091100104008585924551046643952123905));` -// -// The `contains` method may be used to check whether a Point lies on a given curve: -// -// `constrain bjj_affine.contains(tecurve::Point::zero());` -// -// The elliptic curve group's addition operation is exposed as the `add` method, e.g. -// -// `let p = bjj_affine.add(bjj_affine.gen, bjj_affine.gen);` -// -// subtraction as the `subtract` method, e.g. -// -// `constrain tecurve::Point::zero().eq(bjj_affine.subtract(bjj_affine.gen, bjj_affine.gen));` -// -// scalar multiplication as the `mul` method, where the scalar is assumed to be a Field* element, e.g. -// -// `constrain tecurve::Point::zero().eq(bjj_affine.mul(2, tecurve::Point::zero());` -// -// There is a scalar multiplication method (`bit_mul`) provided where the scalar input is expected to be -// an array of bits (little-endian convention), as well as a multi-scalar multiplication method** (`msm`) -// which takes an array of Field elements and an array of elliptic curve points as arguments, both assumed -// to be of the same length. -// -// Curve configurations may be converted between different coordinate representations by calling the `into_group` -// and `into_affine` methods on them, e.g. -// -// `let bjj_curvegroup = bjj_affine.into_group();` -// -// Curve configurations may also be converted between different curve representations by calling the `into_swcurve`, -// `into_montcurve` and `into_tecurve` methods subject to the relation between the curve representations mentioned -// above. 
Note that it is possible to map Points from a Twisted Edwards/Montgomery curve to the corresponding -// Short Weierstrass representation and back, and the methods to do so are exposed as `map_into_swcurve` and -// `map_from_swcurve`, which each take one argument, the point to be mapped. -// -// Curve maps -// ========== -// There are a few different ways of mapping Field elements to elliptic curves. Here we provide the simplified -// Shallue-van de Woestijne-Ulas and Elligator 2 methods, the former being applicable to all curve types -// provided above subject to the constraint that the coefficients of the corresponding Short Weierstrass curve satisfies -// a*b != 0 and the latter being applicable to Montgomery and Twisted Edwards curves subject to the constraint that -// the coefficients of the corresponding Montgomery curve satisfy j*k != 0 and (j^2 - 4)/k^2 is non-square. -// -// The simplified Shallue-van de Woestijne-Ulas method is exposed as the method `swu_map` on the Curve configuration and -// depends on two parameters, a Field element z != -1 for which g(x) - z is irreducible over Field and g(b/(z*a)) is -// square, where g(x) = x^3 + a*x + b is the right-hand side of the defining equation of the corresponding Short -// Weierstrass curve, and a Field element u to be mapped onto the curve. For example, in the case of bjj_affine above, -// it may be determined using the scripts provided at that z = 5. -// -// The Elligator 2 method is exposed as the method `elligator2_map` on the Curve configurations of Montgomery and -// Twisted Edwards curves. Like the simplified SWU method above, it depends on a certain non-square element of Field, -// but this element need not satisfy any further conditions, so it is included as the (Field-dependent) constant -//`ZETA` below. Thus, the `elligator2_map` method depends only on one parameter, the Field element to be mapped onto -// the curve. -// -// For details on all of the above in the context of hashing to elliptic curves, see . -// -// -// *TODO: Replace Field with Bigint. -// **TODO: Support arrays of structs to make this work. -// Field-dependent constant ZETA = a non-square element of Field -// Required for Elligator 2 map -// TODO: Replace with built-in constant. -global ZETA: Field = 5; -// Field-dependent constants for Tonelli-Shanks algorithm (see sqrt function below) -// TODO: Possibly make this built-in. -global C1: u32 = 28; -global C3: Field = 40770029410420498293352137776570907027550720424234931066070132305055; -global C5: Field = 19103219067921713944291392827692070036145651957329286315305642004821462161904; -// Higher-order version of scalar multiplication -// TODO: Make this work so that the submodules' bit_mul may be defined in terms of it. -//fn bit_mul(add: fn(T,T) -> T, e: T, bits: [u1; N], p: T) -> T { -// let mut out = e; -// let n = bits.len(); -// -// for i in 0..n { -// out = add( -// add(out, out), -// if(bits[n - i - 1] == 0) {e} else {p}); -// } -// -// out -//} -// TODO: Make this built-in. -pub fn safe_inverse(x: Field) -> Field { - if x == 0 { - 0 - } else { - 1 / x - } -} -// Boolean indicating whether Field element is a square, i.e. whether there exists a y in Field s.t. x = y*y. -pub fn is_square(x: Field) -> bool { - let v = pow(x, 0 - 1 / 2); - - v * (v - 1) == 0 -} -// Power function of two Field arguments of arbitrary size. -// Adapted from std::field::pow_32. 
-pub fn pow(x: Field, y: Field) -> Field { - let mut r = 1 as Field; - let b: [u1; 254] = y.to_le_bits(); - - for i in 0..254 { - r *= r; - r *= (b[254 - 1 - i] as Field) * x + (1 - b[254 - 1 - i] as Field); - } - - r -} -// Tonelli-Shanks algorithm for computing the square root of a Field element. -// Requires C1 = max{c: 2^c divides (p-1)}, where p is the order of Field -// as well as C3 = (C2 - 1)/2, where C2 = (p-1)/(2^c1), -// and C5 = ZETA^C2, where ZETA is a non-square element of Field. -// These are pre-computed above as globals. -pub fn sqrt(x: Field) -> Field { - let mut z = pow(x, C3); - let mut t = z * z * x; - z *= x; - let mut b = t; - let mut c = C5; - - for i in 0..(C1 - 1) { - for _j in 1..(C1 - i - 1) { - b *= b; - } - - z *= if b == 1 { 1 } else { c }; - - c *= c; - - t *= if b == 1 { 1 } else { c }; - - b = t; - } - - z -} diff --git a/noir_stdlib/src/ec/montcurve.nr b/noir_stdlib/src/ec/montcurve.nr deleted file mode 100644 index 239585ba13f..00000000000 --- a/noir_stdlib/src/ec/montcurve.nr +++ /dev/null @@ -1,387 +0,0 @@ -pub mod affine { - // Affine representation of Montgomery curves - // Points are represented by two-dimensional Cartesian coordinates. - // All group operations are induced by those of the corresponding Twisted Edwards curve. - // See e.g. for details on the correspondences. - use crate::cmp::Eq; - use crate::ec::is_square; - use crate::ec::montcurve::curvegroup; - use crate::ec::safe_inverse; - use crate::ec::sqrt; - use crate::ec::swcurve::affine::Curve as SWCurve; - use crate::ec::swcurve::affine::Point as SWPoint; - use crate::ec::tecurve::affine::Curve as TECurve; - use crate::ec::tecurve::affine::Point as TEPoint; - use crate::ec::ZETA; - - // Curve specification - pub struct Curve { // Montgomery Curve configuration (ky^2 = x^3 + j*x^2 + x) - pub j: Field, - pub k: Field, - // Generator as point in Cartesian coordinates - pub gen: Point, - } - // Point in Cartesian coordinates - pub struct Point { - pub x: Field, - pub y: Field, - pub infty: bool, // Indicator for point at infinity - } - - impl Point { - // Point constructor - pub fn new(x: Field, y: Field) -> Self { - Self { x, y, infty: false } - } - - // Check if zero - pub fn is_zero(self) -> bool { - self.infty - } - - // Conversion to CurveGroup coordinates - pub fn into_group(self) -> curvegroup::Point { - if self.is_zero() { - curvegroup::Point::zero() - } else { - let (x, y) = (self.x, self.y); - curvegroup::Point::new(x, y, 1) - } - } - - // Additive identity - pub fn zero() -> Self { - Self { x: 0, y: 0, infty: true } - } - - // Negation - pub fn negate(self) -> Self { - let Self { x, y, infty } = self; - - Self { x, y: 0 - y, infty } - } - - // Map into equivalent Twisted Edwards curve - pub fn into_tecurve(self) -> TEPoint { - let Self { x, y, infty } = self; - - if infty | (y * (x + 1) == 0) { - TEPoint::zero() - } else { - TEPoint::new(x / y, (x - 1) / (x + 1)) - } - } - } - - impl Eq for Point { - fn eq(self, p: Self) -> bool { - (self.infty & p.infty) | (!self.infty & !p.infty & (self.x == p.x) & (self.y == p.y)) - } - } - - impl Curve { - // Curve constructor - pub fn new(j: Field, k: Field, gen: Point) -> Self { - // Check curve coefficients - assert(k != 0); - assert(j * j != 4); - - let curve = Self { j, k, gen }; - - // gen should be on the curve - assert(curve.contains(curve.gen)); - - curve - } - - // Conversion to CurveGroup coordinates - pub fn into_group(self) -> curvegroup::Curve { - curvegroup::Curve::new(self.j, self.k, self.gen.into_group()) - } - - // 
Membership check - pub fn contains(self, p: Point) -> bool { - let Self { j, k, gen: _gen } = self; - let Point { x, y, infty } = p; - - infty | (k * y * y == x * (x * x + j * x + 1)) - } - - // Point addition - pub fn add(self, p1: Point, p2: Point) -> Point { - self.into_tecurve().add(p1.into_tecurve(), p2.into_tecurve()).into_montcurve() - } - - // Scalar multiplication with scalar represented by a bit array (little-endian convention). - // If k is the natural number represented by `bits`, then this computes p + ... + p k times. - pub fn bit_mul(self, bits: [u1; N], p: Point) -> Point { - self.into_tecurve().bit_mul(bits, p.into_tecurve()).into_montcurve() - } - - // Scalar multiplication (p + ... + p n times) - pub fn mul(self, n: Field, p: Point) -> Point { - self.into_tecurve().mul(n, p.into_tecurve()).into_montcurve() - } - - // Multi-scalar multiplication (n[0]*p[0] + ... + n[N]*p[N], where * denotes scalar multiplication) - pub fn msm(self, n: [Field; N], p: [Point; N]) -> Point { - let mut out = Point::zero(); - - for i in 0..N { - out = self.add(out, self.mul(n[i], p[i])); - } - - out - } - - // Point subtraction - pub fn subtract(self, p1: Point, p2: Point) -> Point { - self.add(p1, p2.negate()) - } - - // Conversion to equivalent Twisted Edwards curve - pub fn into_tecurve(self) -> TECurve { - let Self { j, k, gen } = self; - TECurve::new((j + 2) / k, (j - 2) / k, gen.into_tecurve()) - } - - // Conversion to equivalent Short Weierstrass curve - pub fn into_swcurve(self) -> SWCurve { - let j = self.j; - let k = self.k; - let a0 = (3 - j * j) / (3 * k * k); - let b0 = (2 * j * j * j - 9 * j) / (27 * k * k * k); - - SWCurve::new(a0, b0, self.map_into_swcurve(self.gen)) - } - - // Point mapping into equivalent Short Weierstrass curve - pub fn map_into_swcurve(self, p: Point) -> SWPoint { - if p.is_zero() { - SWPoint::zero() - } else { - SWPoint::new((3 * p.x + self.j) / (3 * self.k), p.y / self.k) - } - } - - // Point mapping from equivalent Short Weierstrass curve - pub fn map_from_swcurve(self, p: SWPoint) -> Point { - let SWPoint { x, y, infty } = p; - let j = self.j; - let k = self.k; - - Point { x: (3 * k * x - j) / 3, y: y * k, infty } - } - - // Elligator 2 map-to-curve method; see . - pub fn elligator2_map(self, u: Field) -> Point { - let j = self.j; - let k = self.k; - let z = ZETA; // Non-square Field element required for map - // Check whether curve is admissible - assert(j != 0); - let l = (j * j - 4) / (k * k); - assert(l != 0); - assert(is_square(l) == false); - - let x1 = safe_inverse(1 + z * u * u) * (0 - (j / k)); - - let gx1 = x1 * x1 * x1 + (j / k) * x1 * x1 + x1 / (k * k); - let x2 = 0 - x1 - (j / k); - let gx2 = x2 * x2 * x2 + (j / k) * x2 * x2 + x2 / (k * k); - - let x = if is_square(gx1) { x1 } else { x2 }; - - let y = if is_square(gx1) { - let y0 = sqrt(gx1); - if y0.sgn0() == 1 { - y0 - } else { - 0 - y0 - } - } else { - let y0 = sqrt(gx2); - if y0.sgn0() == 0 { - y0 - } else { - 0 - y0 - } - }; - - Point::new(x * k, y * k) - } - - // SWU map-to-curve method (via rational map) - pub fn swu_map(self, z: Field, u: Field) -> Point { - self.map_from_swcurve(self.into_swcurve().swu_map(z, u)) - } - } -} -pub mod curvegroup { - // Affine representation of Montgomery curves - // Points are represented by three-dimensional projective (homogeneous) coordinates. - // All group operations are induced by those of the corresponding Twisted Edwards curve. - // See e.g. for details on the correspondences. 
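For reference, the correspondence these operations rely on: the Montgomery curve $ky^2 = x^3 + jx^2 + x$ is birationally equivalent to the Twisted Edwards curve $ax^2 + y^2 = 1 + dx^2y^2$ with

$$a = \frac{j + 2}{k}, \qquad d = \frac{j - 2}{k},$$

and an affine Montgomery point $(x, y)$ maps to the Twisted Edwards point $(x/y, \, (x - 1)/(x + 1))$, exactly as `into_tecurve` computes above.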
- use crate::cmp::Eq; - use crate::ec::montcurve::affine; - use crate::ec::swcurve::curvegroup::Curve as SWCurve; - use crate::ec::swcurve::curvegroup::Point as SWPoint; - use crate::ec::tecurve::curvegroup::Curve as TECurve; - use crate::ec::tecurve::curvegroup::Point as TEPoint; - - pub struct Curve { // Montgomery Curve configuration (ky^2 z = x*(x^2 + j*x*z + z*z)) - pub j: Field, - pub k: Field, - // Generator as point in projective coordinates - pub gen: Point, - } - // Point in projective coordinates - pub struct Point { - pub x: Field, - pub y: Field, - pub z: Field, - } - - impl Point { - // Point constructor - pub fn new(x: Field, y: Field, z: Field) -> Self { - Self { x, y, z } - } - - // Check if zero - pub fn is_zero(self) -> bool { - self.z == 0 - } - - // Conversion to affine coordinates - pub fn into_affine(self) -> affine::Point { - if self.is_zero() { - affine::Point::zero() - } else { - let (x, y, z) = (self.x, self.y, self.z); - affine::Point::new(x / z, y / z) - } - } - - // Additive identity - pub fn zero() -> Self { - Self { x: 0, y: 1, z: 0 } - } - - // Negation - pub fn negate(self) -> Self { - let Self { x, y, z } = self; - - Point::new(x, 0 - y, z) - } - - // Map into equivalent Twisted Edwards curve - pub fn into_tecurve(self) -> TEPoint { - self.into_affine().into_tecurve().into_group() - } - } - - impl Eq for Point { - fn eq(self, p: Self) -> bool { - (self.z == p.z) - | (((self.x * self.z) == (p.x * p.z)) & ((self.y * self.z) == (p.y * p.z))) - } - } - - impl Curve { - // Curve constructor - pub fn new(j: Field, k: Field, gen: Point) -> Self { - // Check curve coefficients - assert(k != 0); - assert(j * j != 4); - - let curve = Self { j, k, gen }; - - // gen should be on the curve - assert(curve.contains(curve.gen)); - - curve - } - - // Conversion to affine coordinates - pub fn into_affine(self) -> affine::Curve { - affine::Curve::new(self.j, self.k, self.gen.into_affine()) - } - - // Membership check - pub fn contains(self, p: Point) -> bool { - let Self { j, k, gen: _gen } = self; - let Point { x, y, z } = p; - - k * y * y * z == x * (x * x + j * x * z + z * z) - } - - // Point addition - pub fn add(self, p1: Point, p2: Point) -> Point { - self.into_affine().add(p1.into_affine(), p2.into_affine()).into_group() - } - - // Scalar multiplication with scalar represented by a bit array (little-endian convention). - // If k is the natural number represented by `bits`, then this computes p + ... + p k times. - pub fn bit_mul(self, bits: [u1; N], p: Point) -> Point { - self.into_tecurve().bit_mul(bits, p.into_tecurve()).into_montcurve() - } - - // Scalar multiplication (p + ... + p n times) - pub fn mul(self, n: Field, p: Point) -> Point { - self.into_tecurve().mul(n, p.into_tecurve()).into_montcurve() - } - - // Multi-scalar multiplication (n[0]*p[0] + ... 
+ n[N]*p[N], where * denotes scalar multiplication) - pub fn msm(self, n: [Field; N], p: [Point; N]) -> Point { - let mut out = Point::zero(); - - for i in 0..N { - out = self.add(out, self.mul(n[i], p[i])); - } - - out - } - - // Point subtraction - pub fn subtract(self, p1: Point, p2: Point) -> Point { - self.add(p1, p2.negate()) - } - - // Conversion to equivalent Twisted Edwards curve - pub fn into_tecurve(self) -> TECurve { - let Self { j, k, gen } = self; - TECurve::new((j + 2) / k, (j - 2) / k, gen.into_tecurve()) - } - - // Conversion to equivalent Short Weierstrass curve - pub fn into_swcurve(self) -> SWCurve { - let j = self.j; - let k = self.k; - let a0 = (3 - j * j) / (3 * k * k); - let b0 = (2 * j * j * j - 9 * j) / (27 * k * k * k); - - SWCurve::new(a0, b0, self.map_into_swcurve(self.gen)) - } - - // Point mapping into equivalent Short Weierstrass curve - pub fn map_into_swcurve(self, p: Point) -> SWPoint { - self.into_affine().map_into_swcurve(p.into_affine()).into_group() - } - - // Point mapping from equivalent Short Weierstrass curve - pub fn map_from_swcurve(self, p: SWPoint) -> Point { - self.into_affine().map_from_swcurve(p.into_affine()).into_group() - } - - // Elligator 2 map-to-curve method - pub fn elligator2_map(self, u: Field) -> Point { - self.into_affine().elligator2_map(u).into_group() - } - - // SWU map-to-curve method (via rational map) - pub fn swu_map(self, z: Field, u: Field) -> Point { - self.into_affine().swu_map(z, u).into_group() - } - } -} diff --git a/noir_stdlib/src/ec/swcurve.nr b/noir_stdlib/src/ec/swcurve.nr deleted file mode 100644 index d9c1cf8c8c7..00000000000 --- a/noir_stdlib/src/ec/swcurve.nr +++ /dev/null @@ -1,394 +0,0 @@ -pub mod affine { - // Affine representation of Short Weierstrass curves - // Points are represented by two-dimensional Cartesian coordinates. - // Group operations are implemented in terms of those in CurveGroup (in this case, extended Twisted Edwards) coordinates - // for reasons of efficiency, cf. . 
- use crate::cmp::Eq; - use crate::ec::is_square; - use crate::ec::safe_inverse; - use crate::ec::sqrt; - use crate::ec::swcurve::curvegroup; - - // Curve specification - pub struct Curve { // Short Weierstrass curve - // Coefficients in defining equation y^2 = x^3 + ax + b - pub a: Field, - pub b: Field, - // Generator as point in Cartesian coordinates - pub gen: Point, - } - // Point in Cartesian coordinates - pub struct Point { - pub x: Field, - pub y: Field, - pub infty: bool, // Indicator for point at infinity - } - - impl Point { - // Point constructor - pub fn new(x: Field, y: Field) -> Self { - Self { x, y, infty: false } - } - - // Check if zero - pub fn is_zero(self) -> bool { - self.eq(Point::zero()) - } - - // Conversion to CurveGroup coordinates - pub fn into_group(self) -> curvegroup::Point { - let Self { x, y, infty } = self; - - if infty { - curvegroup::Point::zero() - } else { - curvegroup::Point::new(x, y, 1) - } - } - - // Additive identity - pub fn zero() -> Self { - Self { x: 0, y: 0, infty: true } - } - - // Negation - pub fn negate(self) -> Self { - let Self { x, y, infty } = self; - Self { x, y: 0 - y, infty } - } - } - - impl Eq for Point { - fn eq(self, p: Self) -> bool { - let Self { x: x1, y: y1, infty: inf1 } = self; - let Self { x: x2, y: y2, infty: inf2 } = p; - - (inf1 & inf2) | (!inf1 & !inf2 & (x1 == x2) & (y1 == y2)) - } - } - - impl Curve { - // Curve constructor - pub fn new(a: Field, b: Field, gen: Point) -> Curve { - // Check curve coefficients - assert(4 * a * a * a + 27 * b * b != 0); - - let curve = Curve { a, b, gen }; - - // gen should be on the curve - assert(curve.contains(curve.gen)); - - curve - } - - // Conversion to CurveGroup coordinates - pub fn into_group(self) -> curvegroup::Curve { - let Curve { a, b, gen } = self; - - curvegroup::Curve { a, b, gen: gen.into_group() } - } - - // Membership check - pub fn contains(self, p: Point) -> bool { - let Point { x, y, infty } = p; - infty | (y * y == x * x * x + self.a * x + self.b) - } - - // Point addition, implemented in terms of mixed addition for reasons of efficiency - pub fn add(self, p1: Point, p2: Point) -> Point { - self.mixed_add(p1, p2.into_group()).into_affine() - } - - // Mixed point addition, i.e. first argument in affine, second in CurveGroup coordinates. - pub fn mixed_add(self, p1: Point, p2: curvegroup::Point) -> curvegroup::Point { - if p1.is_zero() { - p2 - } else if p2.is_zero() { - p1.into_group() - } else { - let Point { x: x1, y: y1, infty: _inf } = p1; - let curvegroup::Point { x: x2, y: y2, z: z2 } = p2; - let you1 = x1 * z2 * z2; - let you2 = x2; - let s1 = y1 * z2 * z2 * z2; - let s2 = y2; - - if you1 == you2 { - if s1 != s2 { - curvegroup::Point::zero() - } else { - self.into_group().double(p2) - } - } else { - let h = you2 - you1; - let r = s2 - s1; - let x3 = r * r - h * h * h - 2 * you1 * h * h; - let y3 = r * (you1 * h * h - x3) - s1 * h * h * h; - let z3 = h * z2; - - curvegroup::Point::new(x3, y3, z3) - } - } - } - - // Scalar multiplication with scalar represented by a bit array (little-endian convention). - // If k is the natural number represented by `bits`, then this computes p + ... + p k times. - pub fn bit_mul(self, bits: [u1; N], p: Point) -> Point { - self.into_group().bit_mul(bits, p.into_group()).into_affine() - } - - // Scalar multiplication (p + ... + p n times) - pub fn mul(self, n: Field, p: Point) -> Point { - self.into_group().mul(n, p.into_group()).into_affine() - } - - // Multi-scalar multiplication (n[0]*p[0] + ... 
+ n[N]*p[N], where * denotes scalar multiplication) - pub fn msm(self, n: [Field; N], p: [Point; N]) -> Point { - let mut out = Point::zero(); - - for i in 0..N { - out = self.add(out, self.mul(n[i], p[i])); - } - - out - } - - // Point subtraction - pub fn subtract(self, p1: Point, p2: Point) -> Point { - self.add(p1, p2.negate()) - } - - // Simplified Shallue-van de Woestijne-Ulas map-to-curve method; see . - // First determine non-square z != -1 in Field s.t. g(x) - z irreducible over Field and g(b/(z*a)) is square, - // where g(x) = x^3 + a*x + b. swu_map(c,z,.) then maps a Field element to a point on curve c. - pub fn swu_map(self, z: Field, u: Field) -> Point { - // Check whether curve is admissible - assert(self.a * self.b != 0); - - let Curve { a, b, gen: _gen } = self; - - let tv1 = safe_inverse(z * z * u * u * u * u + u * u * z); - let x1 = if tv1 == 0 { - b / (z * a) - } else { - (0 - b / a) * (1 + tv1) - }; - let gx1 = x1 * x1 * x1 + a * x1 + b; - let x2 = z * u * u * x1; - let gx2 = x2 * x2 * x2 + a * x2 + b; - let (x, y) = if is_square(gx1) { - (x1, sqrt(gx1)) - } else { - (x2, sqrt(gx2)) - }; - Point::new(x, if u.sgn0() != y.sgn0() { 0 - y } else { y }) - } - } -} - -pub mod curvegroup { - // CurveGroup representation of Weierstrass curves - // Points are represented by three-dimensional Jacobian coordinates. - // See for details. - use crate::cmp::Eq; - use crate::ec::swcurve::affine; - - // Curve specification - pub struct Curve { // Short Weierstrass curve - // Coefficients in defining equation y^2 = x^3 + axz^4 + bz^6 - pub a: Field, - pub b: Field, - // Generator as point in Cartesian coordinates - pub gen: Point, - } - // Point in three-dimensional Jacobian coordinates - pub struct Point { - pub x: Field, - pub y: Field, - pub z: Field, // z = 0 corresponds to point at infinity. 
- } - - impl Point { - // Point constructor - pub fn new(x: Field, y: Field, z: Field) -> Self { - Self { x, y, z } - } - - // Check if zero - pub fn is_zero(self) -> bool { - self.eq(Point::zero()) - } - - // Conversion to affine coordinates - pub fn into_affine(self) -> affine::Point { - let Self { x, y, z } = self; - - if z == 0 { - affine::Point::zero() - } else { - affine::Point::new(x / (z * z), y / (z * z * z)) - } - } - - // Additive identity - pub fn zero() -> Self { - Self { x: 0, y: 0, z: 0 } - } - - // Negation - pub fn negate(self) -> Self { - let Self { x, y, z } = self; - Self { x, y: 0 - y, z } - } - } - - impl Eq for Point { - fn eq(self, p: Self) -> bool { - let Self { x: x1, y: y1, z: z1 } = self; - let Self { x: x2, y: y2, z: z2 } = p; - - ((z1 == 0) & (z2 == 0)) - | ( - (z1 != 0) - & (z2 != 0) - & (x1 * z2 * z2 == x2 * z1 * z1) - & (y1 * z2 * z2 * z2 == y2 * z1 * z1 * z1) - ) - } - } - - impl Curve { - // Curve constructor - pub fn new(a: Field, b: Field, gen: Point) -> Curve { - // Check curve coefficients - assert(4 * a * a * a + 27 * b * b != 0); - - let curve = Curve { a, b, gen }; - - // gen should be on the curve - assert(curve.contains(curve.gen)); - - curve - } - - // Conversion to affine coordinates - pub fn into_affine(self) -> affine::Curve { - let Curve { a, b, gen } = self; - - affine::Curve { a, b, gen: gen.into_affine() } - } - - // Membership check - pub fn contains(self, p: Point) -> bool { - let Point { x, y, z } = p; - if z == 0 { - true - } else { - y * y == x * x * x + self.a * x * z * z * z * z + self.b * z * z * z * z * z * z - } - } - - // Addition - pub fn add(self, p1: Point, p2: Point) -> Point { - if p1.is_zero() { - p2 - } else if p2.is_zero() { - p1 - } else { - let Point { x: x1, y: y1, z: z1 } = p1; - let Point { x: x2, y: y2, z: z2 } = p2; - let you1 = x1 * z2 * z2; - let you2 = x2 * z1 * z1; - let s1 = y1 * z2 * z2 * z2; - let s2 = y2 * z1 * z1 * z1; - - if you1 == you2 { - if s1 != s2 { - Point::zero() - } else { - self.double(p1) - } - } else { - let h = you2 - you1; - let r = s2 - s1; - let x3 = r * r - h * h * h - 2 * you1 * h * h; - let y3 = r * (you1 * h * h - x3) - s1 * h * h * h; - let z3 = h * z1 * z2; - - Point::new(x3, y3, z3) - } - } - } - - // Point doubling - pub fn double(self, p: Point) -> Point { - let Point { x, y, z } = p; - - if p.is_zero() { - p - } else if y == 0 { - Point::zero() - } else { - let s = 4 * x * y * y; - let m = 3 * x * x + self.a * z * z * z * z; - let x0 = m * m - 2 * s; - let y0 = m * (s - x0) - 8 * y * y * y * y; - let z0 = 2 * y * z; - - Point::new(x0, y0, z0) - } - } - - // Scalar multiplication with scalar represented by a bit array (little-endian convention). - // If k is the natural number represented by `bits`, then this computes p + ... + p k times. - pub fn bit_mul(self, bits: [u1; N], p: Point) -> Point { - let mut out = Point::zero(); - - for i in 0..N { - out = self.add( - self.add(out, out), - if (bits[N - i - 1] == 0) { - Point::zero() - } else { - p - }, - ); - } - - out - } - - // Scalar multiplication (p + ... + p n times) - pub fn mul(self, n: Field, p: Point) -> Point { - // TODO: temporary workaround until issue 1354 is solved - let mut n_as_bits: [u1; 254] = [0; 254]; - let tmp: [u1; 254] = n.to_le_bits(); - for i in 0..254 { - n_as_bits[i] = tmp[i]; - } - - self.bit_mul(n_as_bits, p) - } - - // Multi-scalar multiplication (n[0]*p[0] + ... 
+ n[N]*p[N], where * denotes scalar multiplication) - pub fn msm(self, n: [Field; N], p: [Point; N]) -> Point { - let mut out = Point::zero(); - - for i in 0..N { - out = self.add(out, self.mul(n[i], p[i])); - } - - out - } - - // Point subtraction - pub fn subtract(self, p1: Point, p2: Point) -> Point { - self.add(p1, p2.negate()) - } - - // Simplified SWU map-to-curve method - pub fn swu_map(self, z: Field, u: Field) -> Point { - self.into_affine().swu_map(z, u).into_group() - } - } -} diff --git a/noir_stdlib/src/ec/tecurve.nr b/noir_stdlib/src/ec/tecurve.nr deleted file mode 100644 index 45a6b322ed1..00000000000 --- a/noir_stdlib/src/ec/tecurve.nr +++ /dev/null @@ -1,419 +0,0 @@ -pub mod affine { - // Affine coordinate representation of Twisted Edwards curves - // Points are represented by two-dimensional Cartesian coordinates. - // Group operations are implemented in terms of those in CurveGroup (in this case, extended Twisted Edwards) coordinates - // for reasons of efficiency. - // See for details. - use crate::cmp::Eq; - use crate::ec::montcurve::affine::Curve as MCurve; - use crate::ec::montcurve::affine::Point as MPoint; - use crate::ec::swcurve::affine::Curve as SWCurve; - use crate::ec::swcurve::affine::Point as SWPoint; - use crate::ec::tecurve::curvegroup; - - // Curve specification - pub struct Curve { // Twisted Edwards curve - // Coefficients in defining equation ax^2 + y^2 = 1 + dx^2y^2 - pub a: Field, - pub d: Field, - // Generator as point in Cartesian coordinates - pub gen: Point, - } - // Point in Cartesian coordinates - pub struct Point { - pub x: Field, - pub y: Field, - } - - impl Point { - // Point constructor - // #[deprecated("It's recommmended to use the external noir-edwards library (https://github.com/noir-lang/noir-edwards)")] - pub fn new(x: Field, y: Field) -> Self { - Self { x, y } - } - - // Check if zero - pub fn is_zero(self) -> bool { - self.eq(Point::zero()) - } - - // Conversion to CurveGroup coordinates - pub fn into_group(self) -> curvegroup::Point { - let Self { x, y } = self; - - curvegroup::Point::new(x, y, x * y, 1) - } - - // Additive identity - pub fn zero() -> Self { - Point::new(0, 1) - } - - // Negation - pub fn negate(self) -> Self { - let Self { x, y } = self; - Point::new(0 - x, y) - } - - // Map into prime-order subgroup of equivalent Montgomery curve - pub fn into_montcurve(self) -> MPoint { - if self.is_zero() { - MPoint::zero() - } else { - let Self { x, y } = self; - let x0 = (1 + y) / (1 - y); - let y0 = (1 + y) / (x * (1 - y)); - - MPoint::new(x0, y0) - } - } - } - - impl Eq for Point { - fn eq(self, p: Self) -> bool { - let Self { x: x1, y: y1 } = self; - let Self { x: x2, y: y2 } = p; - - (x1 == x2) & (y1 == y2) - } - } - - impl Curve { - // Curve constructor - pub fn new(a: Field, d: Field, gen: Point) -> Curve { - // Check curve coefficients - assert(a * d * (a - d) != 0); - - let curve = Curve { a, d, gen }; - - // gen should be on the curve - assert(curve.contains(curve.gen)); - - curve - } - - // Conversion to CurveGroup coordinates - pub fn into_group(self) -> curvegroup::Curve { - let Curve { a, d, gen } = self; - - curvegroup::Curve { a, d, gen: gen.into_group() } - } - - // Membership check - pub fn contains(self, p: Point) -> bool { - let Point { x, y } = p; - self.a * x * x + y * y == 1 + self.d * x * x * y * y - } - - // Point addition, implemented in terms of mixed addition for reasons of efficiency - pub fn add(self, p1: Point, p2: Point) -> Point { - self.mixed_add(p1, p2.into_group()).into_affine() - } - - // 
Mixed point addition, i.e. first argument in affine, second in CurveGroup coordinates. - pub fn mixed_add(self, p1: Point, p2: curvegroup::Point) -> curvegroup::Point { - let Point { x: x1, y: y1 } = p1; - let curvegroup::Point { x: x2, y: y2, t: t2, z: z2 } = p2; - - let a = x1 * x2; - let b = y1 * y2; - let c = self.d * x1 * y1 * t2; - let e = (x1 + y1) * (x2 + y2) - a - b; - let f = z2 - c; - let g = z2 + c; - let h = b - self.a * a; - - let x = e * f; - let y = g * h; - let t = e * h; - let z = f * g; - - curvegroup::Point::new(x, y, t, z) - } - - // Scalar multiplication with scalar represented by a bit array (little-endian convention). - // If k is the natural number represented by `bits`, then this computes p + ... + p k times. - pub fn bit_mul(self, bits: [u1; N], p: Point) -> Point { - self.into_group().bit_mul(bits, p.into_group()).into_affine() - } - - // Scalar multiplication (p + ... + p n times) - pub fn mul(self, n: Field, p: Point) -> Point { - self.into_group().mul(n, p.into_group()).into_affine() - } - - // Multi-scalar multiplication (n[0]*p[0] + ... + n[N]*p[N], where * denotes scalar multiplication) - pub fn msm(self, n: [Field; N], p: [Point; N]) -> Point { - let mut out = Point::zero(); - - for i in 0..N { - out = self.add(out, self.mul(n[i], p[i])); - } - - out - } - - // Point subtraction - pub fn subtract(self, p1: Point, p2: Point) -> Point { - self.add(p1, p2.negate()) - } - - // Conversion to equivalent Montgomery curve - pub fn into_montcurve(self) -> MCurve { - let j = 2 * (self.a + self.d) / (self.a - self.d); - let k = 4 / (self.a - self.d); - let gen_montcurve = self.gen.into_montcurve(); - - MCurve::new(j, k, gen_montcurve) - } - - // Conversion to equivalent Short Weierstrass curve - pub fn into_swcurve(self) -> SWCurve { - self.into_montcurve().into_swcurve() - } - - // Point mapping into equivalent Short Weierstrass curve - pub fn map_into_swcurve(self, p: Point) -> SWPoint { - self.into_montcurve().map_into_swcurve(p.into_montcurve()) - } - - // Point mapping from equivalent Short Weierstrass curve - pub fn map_from_swcurve(self, p: SWPoint) -> Point { - self.into_montcurve().map_from_swcurve(p).into_tecurve() - } - - // Elligator 2 map-to-curve method (via rational map) - pub fn elligator2_map(self, u: Field) -> Point { - self.into_montcurve().elligator2_map(u).into_tecurve() - } - - // Simplified SWU map-to-curve method (via rational map) - pub fn swu_map(self, z: Field, u: Field) -> Point { - self.into_montcurve().swu_map(z, u).into_tecurve() - } - } -} -pub mod curvegroup { - // CurveGroup coordinate representation of Twisted Edwards curves - // Points are represented by four-dimensional projective coordinates, viz. extended Twisted Edwards coordinates. - // See section 3 of for details. 
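The `into_montcurve` conversion above uses j = 2(a + d)/(a - d) and k = 4/(a - d). A quick worked check with the Baby Jubjub parameters that appear elsewhere in this PR (a = 168700, d = 168696) recovers the curve's well-known Montgomery form A = 168698, B = 1 (illustrative function, not part of the diff; the field divisions are exact since a - d = 4):

```noir
fn baby_jubjub_montgomery_coefficients() {
    let a = 168700;
    let d = 168696;
    let j = 2 * (a + d) / (a - d); // 2 * 337396 / 4 = 168698
    let k = 4 / (a - d); // 4 / 4 = 1
    assert(j == 168698);
    assert(k == 1);
}
```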
- use crate::cmp::Eq; - use crate::ec::montcurve::curvegroup::Curve as MCurve; - use crate::ec::montcurve::curvegroup::Point as MPoint; - use crate::ec::swcurve::curvegroup::Curve as SWCurve; - use crate::ec::swcurve::curvegroup::Point as SWPoint; - use crate::ec::tecurve::affine; - - // Curve specification - pub struct Curve { // Twisted Edwards curve - // Coefficients in defining equation a(x^2 + y^2)z^2 = z^4 + dx^2y^2 - pub a: Field, - pub d: Field, - // Generator as point in projective coordinates - pub gen: Point, - } - // Point in extended twisted Edwards coordinates - pub struct Point { - pub x: Field, - pub y: Field, - pub t: Field, - pub z: Field, - } - - impl Point { - // Point constructor - pub fn new(x: Field, y: Field, t: Field, z: Field) -> Self { - Self { x, y, t, z } - } - - // Check if zero - pub fn is_zero(self) -> bool { - let Self { x, y, t, z } = self; - (x == 0) & (y == z) & (y != 0) & (t == 0) - } - - // Conversion to affine coordinates - pub fn into_affine(self) -> affine::Point { - let Self { x, y, t: _t, z } = self; - - affine::Point::new(x / z, y / z) - } - - // Additive identity - pub fn zero() -> Self { - Point::new(0, 1, 0, 1) - } - - // Negation - pub fn negate(self) -> Self { - let Self { x, y, t, z } = self; - - Point::new(0 - x, y, 0 - t, z) - } - - // Map into prime-order subgroup of equivalent Montgomery curve - pub fn into_montcurve(self) -> MPoint { - self.into_affine().into_montcurve().into_group() - } - } - - impl Eq for Point { - fn eq(self, p: Self) -> bool { - let Self { x: x1, y: y1, t: _t1, z: z1 } = self; - let Self { x: x2, y: y2, t: _t2, z: z2 } = p; - - (x1 * z2 == x2 * z1) & (y1 * z2 == y2 * z1) - } - } - - impl Curve { - // Curve constructor - pub fn new(a: Field, d: Field, gen: Point) -> Curve { - // Check curve coefficients - assert(a * d * (a - d) != 0); - - let curve = Curve { a, d, gen }; - - // gen should be on the curve - assert(curve.contains(curve.gen)); - - curve - } - - // Conversion to affine coordinates - pub fn into_affine(self) -> affine::Curve { - let Curve { a, d, gen } = self; - - affine::Curve { a, d, gen: gen.into_affine() } - } - - // Membership check - pub fn contains(self, p: Point) -> bool { - let Point { x, y, t, z } = p; - - (z != 0) - & (z * t == x * y) - & (z * z * (self.a * x * x + y * y) == z * z * z * z + self.d * x * x * y * y) - } - - // Point addition - pub fn add(self, p1: Point, p2: Point) -> Point { - let Point { x: x1, y: y1, t: t1, z: z1 } = p1; - let Point { x: x2, y: y2, t: t2, z: z2 } = p2; - - let a = x1 * x2; - let b = y1 * y2; - let c = self.d * t1 * t2; - let d = z1 * z2; - let e = (x1 + y1) * (x2 + y2) - a - b; - let f = d - c; - let g = d + c; - let h = b - self.a * a; - - let x = e * f; - let y = g * h; - let t = e * h; - let z = f * g; - - Point::new(x, y, t, z) - } - - // Point doubling, cf. section 3.3 - pub fn double(self, p: Point) -> Point { - let Point { x, y, t: _t, z } = p; - - let a = x * x; - let b = y * y; - let c = 2 * z * z; - let d = self.a * a; - let e = (x + y) * (x + y) - a - b; - let g = d + b; - let f = g - c; - let h = d - b; - - let x0 = e * f; - let y0 = g * h; - let t0 = e * h; - let z0 = f * g; - - Point::new(x0, y0, t0, z0) - } - - // Scalar multiplication with scalar represented by a bit array (little-endian convention). - // If k is the natural number represented by `bits`, then this computes p + ... + p k times. 
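The doc comment just above fixes a little-endian convention for the `bit_mul` that follows: `bits[0]` is the least significant bit, so `[0, 1]` encodes 0·2⁰ + 1·2¹ = 2. That is exactly how the deleted `ec_baby_jubjub` test further down exercises it (sketch against the deleted API):

```noir
fn bit_mul_example(c: Curve, p: Point) {
    // [0, 1] encodes 2 in little-endian, so this is a doubling.
    assert(c.bit_mul([0, 1], p).eq(c.mul(2, p)));
}
```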
- pub fn bit_mul(self, bits: [u1; N], p: Point) -> Point { - let mut out = Point::zero(); - - for i in 0..N { - out = self.add( - self.add(out, out), - if (bits[N - i - 1] == 0) { - Point::zero() - } else { - p - }, - ); - } - - out - } - - // Scalar multiplication (p + ... + p n times) - pub fn mul(self, n: Field, p: Point) -> Point { - // TODO: temporary workaround until issue 1354 is solved - let mut n_as_bits: [u1; 254] = [0; 254]; - let tmp: [u1; 254] = n.to_le_bits(); - for i in 0..254 { - n_as_bits[i] = tmp[i]; - } - - self.bit_mul(n_as_bits, p) - } - - // Multi-scalar multiplication (n[0]*p[0] + ... + n[N]*p[N], where * denotes scalar multiplication) - pub fn msm(self, n: [Field; N], p: [Point; N]) -> Point { - let mut out = Point::zero(); - - for i in 0..N { - out = self.add(out, self.mul(n[i], p[i])); - } - - out - } - - // Point subtraction - pub fn subtract(self, p1: Point, p2: Point) -> Point { - self.add(p1, p2.negate()) - } - - // Conversion to equivalent Montgomery curve - pub fn into_montcurve(self) -> MCurve { - self.into_affine().into_montcurve().into_group() - } - - // Conversion to equivalent Short Weierstrass curve - pub fn into_swcurve(self) -> SWCurve { - self.into_montcurve().into_swcurve() - } - - // Point mapping into equivalent short Weierstrass curve - pub fn map_into_swcurve(self, p: Point) -> SWPoint { - self.into_montcurve().map_into_swcurve(p.into_montcurve()) - } - - // Point mapping from equivalent short Weierstrass curve - pub fn map_from_swcurve(self, p: SWPoint) -> Point { - self.into_montcurve().map_from_swcurve(p).into_tecurve() - } - - // Elligator 2 map-to-curve method (via rational maps) - pub fn elligator2_map(self, u: Field) -> Point { - self.into_montcurve().elligator2_map(u).into_tecurve() - } - - // Simplified SWU map-to-curve method (via rational map) - pub fn swu_map(self, z: Field, u: Field) -> Point { - self.into_montcurve().swu_map(z, u).into_tecurve() - } - } -} diff --git a/noir_stdlib/src/hash/poseidon2.nr b/noir_stdlib/src/hash/poseidon2.nr index f2167c43c2c..419f07a2aca 100644 --- a/noir_stdlib/src/hash/poseidon2.nr +++ b/noir_stdlib/src/hash/poseidon2.nr @@ -13,11 +13,7 @@ pub struct Poseidon2 { impl Poseidon2 { #[no_predicates] pub fn hash(input: [Field; N], message_size: u32) -> Field { - if message_size == N { - Poseidon2::hash_internal(input, N, false) - } else { - Poseidon2::hash_internal(input, message_size, true) - } + Poseidon2::hash_internal(input, message_size, message_size != N) } pub(crate) fn new(iv: Field) -> Poseidon2 { diff --git a/noir_stdlib/src/lib.nr b/noir_stdlib/src/lib.nr index b1aabc48d3b..8e9dc13c13d 100644 --- a/noir_stdlib/src/lib.nr +++ b/noir_stdlib/src/lib.nr @@ -10,7 +10,6 @@ pub mod embedded_curve_ops; pub mod sha256; pub mod sha512; pub mod field; -pub mod ec; pub mod collections; pub mod compat; pub mod convert; diff --git a/test_programs/benchmarks/bench_eddsa_poseidon/Nargo.toml b/test_programs/benchmarks/bench_eddsa_poseidon/Nargo.toml index bc2a779f7b2..6c754f1d107 100644 --- a/test_programs/benchmarks/bench_eddsa_poseidon/Nargo.toml +++ b/test_programs/benchmarks/bench_eddsa_poseidon/Nargo.toml @@ -5,3 +5,4 @@ type = "bin" authors = [""] [dependencies] +ec = { tag = "v0.1.2", git = "https://github.com/noir-lang/ec" } diff --git a/test_programs/benchmarks/bench_eddsa_poseidon/src/main.nr b/test_programs/benchmarks/bench_eddsa_poseidon/src/main.nr index bcb0930cf24..c4a1d4b51f5 100644 --- a/test_programs/benchmarks/bench_eddsa_poseidon/src/main.nr +++ 
b/test_programs/benchmarks/bench_eddsa_poseidon/src/main.nr @@ -1,9 +1,11 @@ use std::default::Default; -use std::ec::consts::te::baby_jubjub; -use std::ec::tecurve::affine::Point as TEPoint; use std::hash::Hasher; use std::hash::poseidon::PoseidonHasher; +use ec::consts::te::baby_jubjub; +use ec::tecurve::affine::Point as TEPoint; + + fn main( msg: pub Field, pub_key_x: Field, diff --git a/test_programs/compile_success_empty/ec_baby_jubjub/Nargo.toml b/test_programs/compile_success_empty/ec_baby_jubjub/Nargo.toml deleted file mode 100644 index fdb0df17112..00000000000 --- a/test_programs/compile_success_empty/ec_baby_jubjub/Nargo.toml +++ /dev/null @@ -1,7 +0,0 @@ -[package] -name = "ec_baby_jubjub" -description = "Baby Jubjub sanity checks" -type = "bin" -authors = [""] - -[dependencies] diff --git a/test_programs/compile_success_empty/ec_baby_jubjub/src/main.nr b/test_programs/compile_success_empty/ec_baby_jubjub/src/main.nr deleted file mode 100644 index caaa51d84f0..00000000000 --- a/test_programs/compile_success_empty/ec_baby_jubjub/src/main.nr +++ /dev/null @@ -1,210 +0,0 @@ -// Tests may be checked against https://github.com/cfrg/draft-irtf-cfrg-hash-to-curve/tree/main/poc -use std::ec::tecurve::affine::Curve as AffineCurve; -use std::ec::tecurve::affine::Point as Gaffine; -use std::ec::tecurve::curvegroup::Point as G; - -use std::ec::swcurve::affine::Point as SWGaffine; -use std::ec::swcurve::curvegroup::Point as SWG; - -use std::compat; -use std::ec::montcurve::affine::Point as MGaffine; -use std::ec::montcurve::curvegroup::Point as MG; - -fn main() { - // This test only makes sense if Field is the right prime field. - if compat::is_bn254() { - // Define Baby Jubjub (ERC-2494) parameters in affine representation - let bjj_affine = AffineCurve::new( - 168700, - 168696, - Gaffine::new( - 995203441582195749578291179787384436505546430278305826713579947235728471134, - 5472060717959818805561601436314318772137091100104008585924551046643952123905, - ), - ); - // Test addition - let p1_affine = Gaffine::new( - 17777552123799933955779906779655732241715742912184938656739573121738514868268, - 2626589144620713026669568689430873010625803728049924121243784502389097019475, - ); - let p2_affine = Gaffine::new( - 16540640123574156134436876038791482806971768689494387082833631921987005038935, - 20819045374670962167435360035096875258406992893633759881276124905556507972311, - ); - - let p3_affine = bjj_affine.add(p1_affine, p2_affine); - assert(p3_affine.eq(Gaffine::new( - 7916061937171219682591368294088513039687205273691143098332585753343424131937, - 14035240266687799601661095864649209771790948434046947201833777492504781204499, - ))); - // Test scalar multiplication - let p4_affine = bjj_affine.mul(2, p1_affine); - assert(p4_affine.eq(Gaffine::new( - 6890855772600357754907169075114257697580319025794532037257385534741338397365, - 4338620300185947561074059802482547481416142213883829469920100239455078257889, - ))); - assert(p4_affine.eq(bjj_affine.bit_mul([0, 1], p1_affine))); - // Test subtraction - let p5_affine = bjj_affine.subtract(p3_affine, p3_affine); - assert(p5_affine.eq(Gaffine::zero())); - // Check that these points are on the curve - assert( - bjj_affine.contains(bjj_affine.gen) - & bjj_affine.contains(p1_affine) - & bjj_affine.contains(p2_affine) - & bjj_affine.contains(p3_affine) - & bjj_affine.contains(p4_affine) - & bjj_affine.contains(p5_affine), - ); - // Test CurveGroup equivalents - let bjj = bjj_affine.into_group(); // Baby Jubjub - let p1 = p1_affine.into_group(); - let p2 = 
p2_affine.into_group(); - let p3 = p3_affine.into_group(); - let p4 = p4_affine.into_group(); - let p5 = p5_affine.into_group(); - // Test addition - assert(p3.eq(bjj.add(p1, p2))); - // Test scalar multiplication - assert(p4.eq(bjj.mul(2, p1))); - assert(p4.eq(bjj.bit_mul([0, 1], p1))); - // Test subtraction - assert(G::zero().eq(bjj.subtract(p3, p3))); - assert(p5.eq(G::zero())); - // Check that these points are on the curve - assert( - bjj.contains(bjj.gen) - & bjj.contains(p1) - & bjj.contains(p2) - & bjj.contains(p3) - & bjj.contains(p4) - & bjj.contains(p5), - ); - // Test SWCurve equivalents of the above - // First the affine representation - let bjj_swcurve_affine = bjj_affine.into_swcurve(); - - let p1_swcurve_affine = bjj_affine.map_into_swcurve(p1_affine); - let p2_swcurve_affine = bjj_affine.map_into_swcurve(p2_affine); - let p3_swcurve_affine = bjj_affine.map_into_swcurve(p3_affine); - let p4_swcurve_affine = bjj_affine.map_into_swcurve(p4_affine); - let p5_swcurve_affine = bjj_affine.map_into_swcurve(p5_affine); - // Addition - assert(p3_swcurve_affine.eq(bjj_swcurve_affine.add(p1_swcurve_affine, p2_swcurve_affine))); - // Doubling - assert(p4_swcurve_affine.eq(bjj_swcurve_affine.mul(2, p1_swcurve_affine))); - assert(p4_swcurve_affine.eq(bjj_swcurve_affine.bit_mul([0, 1], p1_swcurve_affine))); - // Subtraction - assert(SWGaffine::zero().eq(bjj_swcurve_affine.subtract( - p3_swcurve_affine, - p3_swcurve_affine, - ))); - assert(p5_swcurve_affine.eq(SWGaffine::zero())); - // Check that these points are on the curve - assert( - bjj_swcurve_affine.contains(bjj_swcurve_affine.gen) - & bjj_swcurve_affine.contains(p1_swcurve_affine) - & bjj_swcurve_affine.contains(p2_swcurve_affine) - & bjj_swcurve_affine.contains(p3_swcurve_affine) - & bjj_swcurve_affine.contains(p4_swcurve_affine) - & bjj_swcurve_affine.contains(p5_swcurve_affine), - ); - // Then the CurveGroup representation - let bjj_swcurve = bjj.into_swcurve(); - - let p1_swcurve = bjj.map_into_swcurve(p1); - let p2_swcurve = bjj.map_into_swcurve(p2); - let p3_swcurve = bjj.map_into_swcurve(p3); - let p4_swcurve = bjj.map_into_swcurve(p4); - let p5_swcurve = bjj.map_into_swcurve(p5); - // Addition - assert(p3_swcurve.eq(bjj_swcurve.add(p1_swcurve, p2_swcurve))); - // Doubling - assert(p4_swcurve.eq(bjj_swcurve.mul(2, p1_swcurve))); - assert(p4_swcurve.eq(bjj_swcurve.bit_mul([0, 1], p1_swcurve))); - // Subtraction - assert(SWG::zero().eq(bjj_swcurve.subtract(p3_swcurve, p3_swcurve))); - assert(p5_swcurve.eq(SWG::zero())); - // Check that these points are on the curve - assert( - bjj_swcurve.contains(bjj_swcurve.gen) - & bjj_swcurve.contains(p1_swcurve) - & bjj_swcurve.contains(p2_swcurve) - & bjj_swcurve.contains(p3_swcurve) - & bjj_swcurve.contains(p4_swcurve) - & bjj_swcurve.contains(p5_swcurve), - ); - // Test MontCurve conversions - // First the affine representation - let bjj_montcurve_affine = bjj_affine.into_montcurve(); - - let p1_montcurve_affine = p1_affine.into_montcurve(); - let p2_montcurve_affine = p2_affine.into_montcurve(); - let p3_montcurve_affine = p3_affine.into_montcurve(); - let p4_montcurve_affine = p4_affine.into_montcurve(); - let p5_montcurve_affine = p5_affine.into_montcurve(); - // Addition - assert(p3_montcurve_affine.eq(bjj_montcurve_affine.add( - p1_montcurve_affine, - p2_montcurve_affine, - ))); - // Doubling - assert(p4_montcurve_affine.eq(bjj_montcurve_affine.mul(2, p1_montcurve_affine))); - assert(p4_montcurve_affine.eq(bjj_montcurve_affine.bit_mul([0, 1], p1_montcurve_affine))); - // 
Subtraction - assert(MGaffine::zero().eq(bjj_montcurve_affine.subtract( - p3_montcurve_affine, - p3_montcurve_affine, - ))); - assert(p5_montcurve_affine.eq(MGaffine::zero())); - // Check that these points are on the curve - assert( - bjj_montcurve_affine.contains(bjj_montcurve_affine.gen) - & bjj_montcurve_affine.contains(p1_montcurve_affine) - & bjj_montcurve_affine.contains(p2_montcurve_affine) - & bjj_montcurve_affine.contains(p3_montcurve_affine) - & bjj_montcurve_affine.contains(p4_montcurve_affine) - & bjj_montcurve_affine.contains(p5_montcurve_affine), - ); - // Then the CurveGroup representation - let bjj_montcurve = bjj.into_montcurve(); - - let p1_montcurve = p1_montcurve_affine.into_group(); - let p2_montcurve = p2_montcurve_affine.into_group(); - let p3_montcurve = p3_montcurve_affine.into_group(); - let p4_montcurve = p4_montcurve_affine.into_group(); - let p5_montcurve = p5_montcurve_affine.into_group(); - // Addition - assert(p3_montcurve.eq(bjj_montcurve.add(p1_montcurve, p2_montcurve))); - // Doubling - assert(p4_montcurve.eq(bjj_montcurve.mul(2, p1_montcurve))); - assert(p4_montcurve.eq(bjj_montcurve.bit_mul([0, 1], p1_montcurve))); - // Subtraction - assert(MG::zero().eq(bjj_montcurve.subtract(p3_montcurve, p3_montcurve))); - assert(p5_montcurve.eq(MG::zero())); - // Check that these points are on the curve - assert( - bjj_montcurve.contains(bjj_montcurve.gen) - & bjj_montcurve.contains(p1_montcurve) - & bjj_montcurve.contains(p2_montcurve) - & bjj_montcurve.contains(p3_montcurve) - & bjj_montcurve.contains(p4_montcurve) - & bjj_montcurve.contains(p5_montcurve), - ); - // Elligator 2 map-to-curve - let ell2_pt_map = bjj_affine.elligator2_map(27); - - assert(ell2_pt_map.eq(MGaffine::new( - 7972459279704486422145701269802978968072470631857513331988813812334797879121, - 8142420778878030219043334189293412482212146646099536952861607542822144507872, - ) - .into_tecurve())); - // SWU map-to-curve - let swu_pt_map = bjj_affine.swu_map(5, 27); - - assert(swu_pt_map.eq(bjj_affine.map_from_swcurve(SWGaffine::new( - 2162719247815120009132293839392097468339661471129795280520343931405114293888, - 5341392251743377373758788728206293080122949448990104760111875914082289313973, - )))); - } -} diff --git a/test_programs/compile_success_empty/regression_2099/Nargo.toml b/test_programs/compile_success_empty/regression_2099/Nargo.toml index 6b9f9a24038..69fd4caabed 100644 --- a/test_programs/compile_success_empty/regression_2099/Nargo.toml +++ b/test_programs/compile_success_empty/regression_2099/Nargo.toml @@ -2,4 +2,6 @@ name = "regression_2099" type = "bin" authors = [""] + [dependencies] +ec = { tag = "v0.1.2", git = "https://github.com/noir-lang/ec" } diff --git a/test_programs/compile_success_empty/regression_2099/src/main.nr b/test_programs/compile_success_empty/regression_2099/src/main.nr index 3fe3cdaf39a..3a8b9092792 100644 --- a/test_programs/compile_success_empty/regression_2099/src/main.nr +++ b/test_programs/compile_success_empty/regression_2099/src/main.nr @@ -1,5 +1,5 @@ -use std::ec::tecurve::affine::Curve as AffineCurve; -use std::ec::tecurve::affine::Point as Gaffine; +use ec::tecurve::affine::Curve as AffineCurve; +use ec::tecurve::affine::Point as Gaffine; fn main() { // Define Baby Jubjub (ERC-2494) parameters in affine representation diff --git a/test_programs/execution_success/eddsa/Prover.toml b/test_programs/execution_success/eddsa/Prover.toml deleted file mode 100644 index 53555202ca6..00000000000 --- a/test_programs/execution_success/eddsa/Prover.toml +++ /dev/null 
@@ -1,3 +0,0 @@ -_priv_key_a = 123 -_priv_key_b = 456 -msg = 789 diff --git a/test_programs/execution_success/eddsa/src/main.nr b/test_programs/execution_success/eddsa/src/main.nr deleted file mode 100644 index 83e6c6565a9..00000000000 --- a/test_programs/execution_success/eddsa/src/main.nr +++ /dev/null @@ -1,128 +0,0 @@ -use std::compat; -use std::ec::consts::te::baby_jubjub; -use std::ec::tecurve::affine::Point as TEPoint; -use std::hash::Hasher; -use std::hash::poseidon::PoseidonHasher; -use std::hash::poseidon2::Poseidon2Hasher; - -fn main(msg: pub Field, _priv_key_a: Field, _priv_key_b: Field) { - // Skip this test for non-bn254 backends - if compat::is_bn254() { - let bjj = baby_jubjub(); - - let pub_key_a = bjj.curve.mul(_priv_key_a, bjj.curve.gen); - let pub_key_b = bjj.curve.mul(_priv_key_b, bjj.curve.gen); - let (pub_key_a_x, pub_key_a_y) = eddsa_to_pub(_priv_key_a); - let (pub_key_b_x, pub_key_b_y) = eddsa_to_pub(_priv_key_b); - assert(TEPoint::new(pub_key_a_x, pub_key_a_y) == pub_key_a); - assert(TEPoint::new(pub_key_b_x, pub_key_b_y) == pub_key_b); - // Manually computed as fields can't use modulo. Importantly, the commitment is within - // the subgroup order. Note that choice of hash is flexible for this step. - // let r_a = hash::pedersen_commitment([_priv_key_a, msg])[0] % bjj.suborder; // modulus computed manually - let r_a = 1414770703199880747815475415092878800081323795074043628810774576767372531818; - // let r_b = hash::pedersen_commitment([_priv_key_b, msg])[0] % bjj.suborder; // modulus computed manually - let r_b = 571799555715456644614141527517766533395606396271089506978608487688924659618; - - let r8_a = bjj.curve.mul(r_a, bjj.base8); - let r8_b = bjj.curve.mul(r_b, bjj.base8); - // let h_a: [Field; 6] = hash::poseidon::bn254::hash_5([ - // r8_a.x, - // r8_a.y, - // pub_key_a.x, - // pub_key_a.y, - // msg, - // ]); - // let h_b: [Field; 6] = hash::poseidon::bn254::hash_5([ - // r8_b.x, - // r8_b.y, - // pub_key_b.x, - // pub_key_b.y, - // msg, - // ]); - // let s_a = (r_a + _priv_key_a * h_a) % bjj.suborder; // modulus computed manually - let s_a = 30333430637424319196043722294837632681219980330991241982145549329256671548; - // let s_b = (r_b + _priv_key_b * h_b) % bjj.suborder; // modulus computed manually - let s_b = 1646085314320208098241070054368798527940102577261034947654839408482102287019; - // User A verifies their signature over the message - assert(eddsa_poseidon_verify(pub_key_a.x, pub_key_a.y, s_a, r8_a.x, r8_a.y, msg)); - // User B's signature over the message can't be used with user A's pub key - assert(!eddsa_poseidon_verify(pub_key_a.x, pub_key_a.y, s_b, r8_b.x, r8_b.y, msg)); - // User A's signature over the message can't be used with another message - assert(!eddsa_poseidon_verify(pub_key_a.x, pub_key_a.y, s_a, r8_a.x, r8_a.y, msg + 1)); - // Using a different hash should fail - assert( - !eddsa_verify::<Poseidon2Hasher>(pub_key_a.x, pub_key_a.y, s_a, r8_a.x, r8_a.y, msg), - ); - } -} - -// Returns true if signature is valid -pub fn eddsa_poseidon_verify( - pub_key_x: Field, - pub_key_y: Field, - signature_s: Field, - signature_r8_x: Field, - signature_r8_y: Field, - message: Field, -) -> bool { - eddsa_verify::<PoseidonHasher>( - pub_key_x, - pub_key_y, - signature_s, - signature_r8_x, - signature_r8_y, - message, - ) -} - -pub fn eddsa_verify<H>( - pub_key_x: Field, - pub_key_y: Field, - signature_s: Field, - signature_r8_x: Field, - signature_r8_y: Field, - message: Field, -) -> bool -where - H: Hasher + Default, -{ - // Verifies by testing: - // S * B8 = R8 + H(R8, A, m) * A8 -
let bjj = baby_jubjub(); - - let pub_key = TEPoint::new(pub_key_x, pub_key_y); - assert(bjj.curve.contains(pub_key)); - - let signature_r8 = TEPoint::new(signature_r8_x, signature_r8_y); - assert(bjj.curve.contains(signature_r8)); - // Ensure S < Subgroup Order - assert(signature_s.lt(bjj.suborder)); - // Calculate the h = H(R, A, msg) - let mut hasher = H::default(); - hasher.write(signature_r8_x); - hasher.write(signature_r8_y); - hasher.write(pub_key_x); - hasher.write(pub_key_y); - hasher.write(message); - let hash: Field = hasher.finish(); - // Calculate second part of the right side: right2 = h*8*A - // Multiply by 8 by doubling 3 times. This also ensures that the result is in the subgroup. - let pub_key_mul_2 = bjj.curve.add(pub_key, pub_key); - let pub_key_mul_4 = bjj.curve.add(pub_key_mul_2, pub_key_mul_2); - let pub_key_mul_8 = bjj.curve.add(pub_key_mul_4, pub_key_mul_4); - // We check that A8 is not zero. - assert(!pub_key_mul_8.is_zero()); - // Compute the right side: R8 + h * A8 - let right = bjj.curve.add(signature_r8, bjj.curve.mul(hash, pub_key_mul_8)); - // Calculate left side of equation left = S * B8 - let left = bjj.curve.mul(signature_s, bjj.base8); - - left.eq(right) -} - -// Returns the public key of the given secret key as (pub_key_x, pub_key_y) -pub fn eddsa_to_pub(secret: Field) -> (Field, Field) { - let bjj = baby_jubjub(); - let pub_key = bjj.curve.mul(secret, bjj.curve.gen); - (pub_key.x, pub_key.y) -} diff --git a/test_programs/execution_success/inline_decompose_hint_brillig_call/Nargo.toml b/test_programs/execution_success/inline_decompose_hint_brillig_call/Nargo.toml new file mode 100644 index 00000000000..ecac2dfb197 --- /dev/null +++ b/test_programs/execution_success/inline_decompose_hint_brillig_call/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "inline_decompose_hint_brillig_call" +version = "0.1.0" +type = "bin" +authors = [""] + +[dependencies] diff --git a/test_programs/execution_success/inline_decompose_hint_brillig_call/src/main.nr b/test_programs/execution_success/inline_decompose_hint_brillig_call/src/main.nr new file mode 100644 index 00000000000..e500f0f976d --- /dev/null +++ b/test_programs/execution_success/inline_decompose_hint_brillig_call/src/main.nr @@ -0,0 +1,15 @@ +use std::embedded_curve_ops::{EmbeddedCurvePoint, EmbeddedCurveScalar, fixed_base_scalar_mul}; + +fn main() -> pub Field { + let pre_address = 0x23d95e303879a5d0bbef78ecbc335e559da37431f6dcd11da54ed375c2846813; + let (a, b) = std::field::bn254::decompose(pre_address); + let curve = EmbeddedCurveScalar { lo: a, hi: b }; + let key = fixed_base_scalar_mul(curve); + let point = EmbeddedCurvePoint { + x: 0x111223493147f6785514b1c195bb37a2589f22a6596d30bb2bb145fdc9ca8f1e, + y: 0x273bbffd678edce8fe30e0deafc4f66d58357c06fd4a820285294b9746c3be95, + is_infinite: false, + }; + let address_point = key.add(point); + address_point.x +} diff --git a/test_programs/execution_success/loop_invariant_regression/src/main.nr b/test_programs/execution_success/loop_invariant_regression/src/main.nr index 25f6e92f868..c28ce063116 100644 --- a/test_programs/execution_success/loop_invariant_regression/src/main.nr +++ b/test_programs/execution_success/loop_invariant_regression/src/main.nr @@ -2,6 +2,7 @@ // to be hoisted to the loop's pre-header block. 
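For context on the pass this regression test exercises: loop-invariant code motion rewrites a loop so that subexpressions independent of the induction variable are computed once in the pre-header. A hand-hoisted sketch of the shape the optimizer aims for (names illustrative, not part of the diff):

```noir
fn hand_hoisted(upper_bound: u32, x: u32, y: u32) {
    let z = x * y; // invariant: computed once, before the loop
    for _ in 0..upper_bound {
        assert_eq(z, 12);
    }
}
```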
fn main(x: u32, y: u32) { loop(4, x, y); + array_read_loop(4, x); } fn loop(upper_bound: u32, x: u32, y: u32) { @@ -11,3 +12,15 @@ fn loop(upper_bound: u32, x: u32, y: u32) { assert_eq(z, 12); } } + +fn array_read_loop(upper_bound: u32, x: u32) { + let arr = [2; 5]; + for i in 0..upper_bound { + for j in 0..upper_bound { + for _ in 0..upper_bound { + assert_eq(arr[i], x); + assert_eq(arr[j], x); + } + } + } +} diff --git a/test_programs/execution_success/eddsa/Nargo.toml b/test_programs/execution_success/negated_jmpif_condition/Nargo.toml similarity index 51% rename from test_programs/execution_success/eddsa/Nargo.toml rename to test_programs/execution_success/negated_jmpif_condition/Nargo.toml index 0f545c2febc..c83e2c1c1fd 100644 --- a/test_programs/execution_success/eddsa/Nargo.toml +++ b/test_programs/execution_success/negated_jmpif_condition/Nargo.toml @@ -1,6 +1,5 @@ [package] -name = "eddsa" -description = "Eddsa verification" +name = "negated_jmpif_condition" type = "bin" authors = [""] diff --git a/test_programs/execution_success/negated_jmpif_condition/Prover.toml b/test_programs/execution_success/negated_jmpif_condition/Prover.toml new file mode 100644 index 00000000000..151faa5a9b1 --- /dev/null +++ b/test_programs/execution_success/negated_jmpif_condition/Prover.toml @@ -0,0 +1 @@ +x = "2" \ No newline at end of file diff --git a/test_programs/execution_success/negated_jmpif_condition/src/main.nr b/test_programs/execution_success/negated_jmpif_condition/src/main.nr new file mode 100644 index 00000000000..06de2b41820 --- /dev/null +++ b/test_programs/execution_success/negated_jmpif_condition/src/main.nr @@ -0,0 +1,9 @@ +fn main(mut x: Field) { + let mut q = 0; + + if x != 10 { + q = 2; + } + + assert(q == 2); +} diff --git a/test_programs/execution_success/regression/Prover.toml b/test_programs/execution_success/regression/Prover.toml index 2875190982f..c81cbf10fbb 100644 --- a/test_programs/execution_success/regression/Prover.toml +++ b/test_programs/execution_success/regression/Prover.toml @@ -1,2 +1,4 @@ x = [0x3f, 0x1c, 0xb8, 0x99, 0xab] z = 3 +u = "169" +v = "-13" \ No newline at end of file diff --git a/test_programs/execution_success/regression/src/main.nr b/test_programs/execution_success/regression/src/main.nr index 1c2f557d2cd..809fdbe4b28 100644 --- a/test_programs/execution_success/regression/src/main.nr +++ b/test_programs/execution_success/regression/src/main.nr @@ -94,7 +94,7 @@ fn bitshift_variable(idx: u8) -> u64 { bits } -fn main(x: [u8; 5], z: Field) { +fn main(x: [u8; 5], z: Field, u: i16, v: i16) { //Issue 1144 let (nib, len) = compact_decode(x, z); assert(len == 5); @@ -130,4 +130,12 @@ fn main(x: [u8; 5], z: Field) { assert(result_0 == 1); let result_4 = bitshift_variable(4); assert(result_4 == 16); + + // Issue 6609 + assert(u % -13 == 0); + assert(u % v == 0); + assert(u % -11 == 4); + assert(-u % -11 == -4); + assert(u % -11 == u % (v + 2)); + assert(-u % -11 == -u % (v + 2)); } diff --git a/test_programs/memory_report.sh b/test_programs/memory_report.sh new file mode 100755 index 00000000000..1b8274b76cc --- /dev/null +++ b/test_programs/memory_report.sh @@ -0,0 +1,48 @@ +#!/usr/bin/env bash +set -e + +sudo apt-get install heaptrack + +NARGO="nargo" + + +# Tests to be profiled for memory report +tests_to_profile=("keccak256" "workspace" "regression_4709" "ram_blowup_regression") + +current_dir=$(pwd) +execution_success_path="$current_dir/execution_success" +test_dirs=$(ls $execution_success_path) + +FIRST="1" + +echo "{\"memory_reports\": [ " > 
memory_report.json + + +for test_name in ${tests_to_profile[@]}; do + full_path=$execution_success_path"/"$test_name + cd $full_path + + if [ $FIRST = "1" ] + then + FIRST="0" + else + echo " ," >> $current_dir"/memory_report.json" + fi + heaptrack --output $current_dir/$test_name"_heap" $NARGO compile --force + if test -f $current_dir/$test_name"_heap.gz"; + then + heaptrack --analyze $current_dir/$test_name"_heap.gz" > $current_dir/$test_name"_heap_analysis.txt" + rm $current_dir/$test_name"_heap.gz" + else + heaptrack --analyze $current_dir/$test_name"_heap.zst" > $current_dir/$test_name"_heap_analysis.txt" + rm $current_dir/$test_name"_heap.zst" + fi + consumption="$(grep 'peak heap memory consumption' $current_dir/$test_name'_heap_analysis.txt')" + len=${#consumption}-30 + peak=${consumption:30:len} + rm $current_dir/$test_name"_heap_analysis.txt" + echo -e " {\n \"artifact_name\":\"$test_name\",\n \"peak_memory\":\"$peak\"\n }" >> $current_dir"/memory_report.json" +done + +echo "]}" >> $current_dir"/memory_report.json" + diff --git a/test_programs/noir_test_success/comptime_blackbox/Nargo.toml b/test_programs/noir_test_success/comptime_blackbox/Nargo.toml new file mode 100644 index 00000000000..5eac6f3c91a --- /dev/null +++ b/test_programs/noir_test_success/comptime_blackbox/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "comptime_blackbox" +type = "bin" +authors = [""] +compiler_version = ">=0.27.0" + +[dependencies] diff --git a/test_programs/noir_test_success/comptime_blackbox/src/main.nr b/test_programs/noir_test_success/comptime_blackbox/src/main.nr new file mode 100644 index 00000000000..c3784e73b09 --- /dev/null +++ b/test_programs/noir_test_success/comptime_blackbox/src/main.nr @@ -0,0 +1,155 @@ +//! Tests to show that the comptime interpreter implement blackbox functions. +use std::bigint; +use std::embedded_curve_ops::{EmbeddedCurvePoint, EmbeddedCurveScalar, multi_scalar_mul}; + +/// Test that all bigint operations work in comptime. +#[test] +fn test_bigint() { + let result: [u8] = comptime { + let a = bigint::Secpk1Fq::from_le_bytes(&[0, 1, 2, 3, 4]); + let b = bigint::Secpk1Fq::from_le_bytes(&[5, 6, 7, 8, 9]); + let c = (a + b) * b / a - a; + c.to_le_bytes() + }; + // Do the same calculation outside comptime. + let a = bigint::Secpk1Fq::from_le_bytes(&[0, 1, 2, 3, 4]); + let b = bigint::Secpk1Fq::from_le_bytes(&[5, 6, 7, 8, 9]); + let c = bigint::Secpk1Fq::from_le_bytes(result); + assert_eq(c, (a + b) * b / a - a); +} + +/// Test that to_le_radix returns an array. 
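The comptime tests in this new file all follow one pattern: evaluate a black-box or stdlib call inside a `comptime { .. }` block, then compare the result against a known output or against the same computation at runtime. A minimal instance of the pattern (not in the diff):

```noir
#[test]
fn comptime_matches_runtime() {
    let c = comptime { 2 + 3 };
    assert_eq(c, 2 + 3);
}
```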
+#[test] +fn test_to_le_radix() { + comptime { + let field = 2; + let bytes: [u8; 8] = field.to_le_radix(256); + let _num = bigint::BigInt::from_le_bytes(bytes, bigint::bn254_fq); + }; +} + +#[test] +fn test_bitshift() { + let c = comptime { + let a: i32 = 10; + let b: u32 = 4; + a << b as u8 + }; + assert_eq(c, 160); +} + +#[test] +fn test_aes128_encrypt() { + let ciphertext = comptime { + let plaintext: [u8; 5] = [1, 2, 3, 4, 5]; + let iv: [u8; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]; + let key: [u8; 16] = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]; + std::aes128::aes128_encrypt(plaintext, iv, key) + }; + let clear_len = 5; + let cipher_len = clear_len + 16 - clear_len % 16; + assert_eq(ciphertext.len(), cipher_len); +} + +#[test] +fn test_blake2s() { + let hash = comptime { + let input = [104, 101, 108, 108, 111]; + std::hash::blake2s(input) + }; + assert_eq(hash[0], 0x19); + assert_eq(hash[31], 0x25); +} + +#[test] +fn test_blake3() { + let hash = comptime { + let input = [104, 101, 108, 108, 111]; + std::hash::blake3(input) + }; + assert_eq(hash[0], 0xea); + assert_eq(hash[31], 0x0f); +} + +/// Test that ecdsa_secp256k1 is implemented. +#[test] +fn test_ecdsa_secp256k1() { + let (valid_array, valid_slice) = comptime { + let pub_key_x: [u8; 32] = hex_to_bytes("a0434d9e47f3c86235477c7b1ae6ae5d3442d49b1943c2b752a68e2a47e247c7").as_array(); + let pub_key_y: [u8; 32] = hex_to_bytes("893aba425419bc27a3b6c7e693a24c696f794c2ed877a1593cbee53b037368d7").as_array(); + let signature: [u8; 64] = hex_to_bytes("e5081c80ab427dc370346f4a0e31aa2bad8d9798c38061db9ae55a4e8df454fd28119894344e71b78770cc931d61f480ecbb0b89d6eb69690161e49a715fcd55").as_array(); + let hashed_message: [u8; 32] = hex_to_bytes("3a73f4123a5cd2121f21cd7e8d358835476949d035d9c2da6806b4633ac8c1e2").as_array(); + + let valid_array = std::ecdsa_secp256k1::verify_signature(pub_key_x, pub_key_y, signature, hashed_message); + let valid_slice = std::ecdsa_secp256k1::verify_signature_slice(pub_key_x, pub_key_y, signature, hashed_message.as_slice()); + + (valid_array, valid_slice) + }; + assert(valid_array); + assert(valid_slice); +} + +/// Test that ecdsa_secp256r1 is implemented. +#[test] +fn test_ecdsa_secp256r1() { + let (valid_array, valid_slice) = comptime { + let pub_key_x: [u8; 32] = hex_to_bytes("550f471003f3df97c3df506ac797f6721fb1a1fb7b8f6f83d224498a65c88e24").as_array(); + let pub_key_y: [u8; 32] = hex_to_bytes("136093d7012e509a73715cbd0b00a3cc0ff4b5c01b3ffa196ab1fb327036b8e6").as_array(); + let signature: [u8; 64] = hex_to_bytes("2c70a8d084b62bfc5ce03641caf9f72ad4da8c81bfe6ec9487bb5e1bef62a13218ad9ee29eaf351fdc50f1520c425e9b908a07278b43b0ec7b872778c14e0784").as_array(); + let hashed_message: [u8; 32] = hex_to_bytes("54705ba3baafdbdfba8c5f9a70f7a89bee98d906b53e31074da7baecdc0da9ad").as_array(); + + let valid_array = std::ecdsa_secp256r1::verify_signature(pub_key_x, pub_key_y, signature, hashed_message); + let valid_slice = std::ecdsa_secp256r1::verify_signature_slice(pub_key_x, pub_key_y, signature, hashed_message.as_slice()); + (valid_array, valid_slice) + }; + assert(valid_array); + assert(valid_slice); +} + +/// Test that sha256_compression is implemented. +#[test] +fn test_sha256() { + let hash = comptime { + let input: [u8; 1] = [0xbd]; + std::hash::sha256(input) + }; + assert_eq(hash[0], 0x68); + assert_eq(hash[31], 0x2b); +} + +/// Test that `embedded_curve_add` and `multi_scalar_mul` are implemented. 
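One step worth spelling out from `test_aes128_encrypt` above: AES-128-CBC pads the plaintext up to a whole 16-byte block, so `clear_len + 16 - clear_len % 16` maps a 5-byte input to 5 + 16 - 5 = 16 bytes. A standalone check of that arithmetic (illustrative only):

```noir
fn padded_len_check() {
    let clear_len: u32 = 5;
    let cipher_len = clear_len + 16 - clear_len % 16; // = 16
    assert_eq(cipher_len, 16);
}
```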
+#[test] +fn test_embedded_curve_ops() { + let (sum, mul) = comptime { + let g1 = EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false }; + let s1 = EmbeddedCurveScalar { lo: 1, hi: 0 }; + let sum = g1 + g1; + let mul = multi_scalar_mul([g1, g1], [s1, s1]); + (sum, mul) + }; + assert_eq(sum, mul); +} + +/// Parse a lowercase hexadecimal string (without 0x prefix) as a byte slice. +comptime fn hex_to_bytes<let N: u32>(s: str<N>) -> [u8] { + assert(N % 2 == 0); + let mut out = &[]; + let bz = s.as_bytes(); + let mut h: u32 = 0; + for i in 0 .. bz.len() { + let ascii = bz[i]; + let d = if ascii < 58 { + ascii - 48 + } else { + assert(ascii >= 97); // enforce >= 'a' + assert(ascii <= 102); // enforce <= 'f' + ascii - 87 + }; + h = h * 16 + d as u32; + if i % 2 == 1 { + out = out.push_back(h as u8); + h = 0; + } + } + out +} diff --git a/tooling/acvm_cli/src/cli/execute_cmd.rs b/tooling/acvm_cli/src/cli/execute_cmd.rs index c453936568c..bf5969718e5 100644 --- a/tooling/acvm_cli/src/cli/execute_cmd.rs +++ b/tooling/acvm_cli/src/cli/execute_cmd.rs @@ -8,7 +8,7 @@ use clap::Args; use crate::cli::fs::inputs::{read_bytecode_from_file, read_inputs_from_file}; use crate::errors::CliError; -use nargo::ops::{execute_program, DefaultForeignCallExecutor}; +use nargo::{foreign_calls::DefaultForeignCallExecutor, ops::execute_program}; use super::fs::witness::{create_output_witness_string, save_witness_to_dir}; diff --git a/tooling/debugger/src/foreign_calls.rs b/tooling/debugger/src/foreign_calls.rs index 6a773a4b348..ecf27a22f29 100644 --- a/tooling/debugger/src/foreign_calls.rs +++ b/tooling/debugger/src/foreign_calls.rs @@ -3,7 +3,7 @@ use acvm::{ pwg::ForeignCallWaitInfo, AcirField, FieldElement, }; -use nargo::ops::{DefaultForeignCallExecutor, ForeignCallExecutor}; +use nargo::foreign_calls::{DefaultForeignCallExecutor, ForeignCallExecutor}; use noirc_artifacts::debug::{DebugArtifact, DebugVars, StackFrame}; use noirc_errors::debug_info::{DebugFnId, DebugVarId}; use noirc_printable_type::ForeignCallError; diff --git a/tooling/debugger/tests/debug.rs b/tooling/debugger/tests/debug.rs index 2dca6b95f0e..eb43cf9cc6d 100644 --- a/tooling/debugger/tests/debug.rs +++ b/tooling/debugger/tests/debug.rs @@ -12,7 +12,7 @@ mod tests { let nargo_bin = cargo_bin("nargo").into_os_string().into_string().expect("Cannot parse nargo path"); - let timeout_seconds = 25; + let timeout_seconds = 30; let mut dbg_session = spawn_bash(Some(timeout_seconds * 1000)).expect("Could not start bash session"); diff --git a/tooling/lsp/src/requests/code_action/import_or_qualify.rs b/tooling/lsp/src/requests/code_action/import_or_qualify.rs index ffc83b05a5b..609a81bdfe7 100644 --- a/tooling/lsp/src/requests/code_action/import_or_qualify.rs +++ b/tooling/lsp/src/requests/code_action/import_or_qualify.rs @@ -183,6 +183,31 @@ mod foo { } } +fn foo(x: SomeTypeInBar) {}"#; + + assert_code_action(title, src, expected).await; + } + + #[test] + async fn test_import_code_action_for_struct_at_beginning_of_name() { + let title = "Import foo::bar::SomeTypeInBar"; + + let src = r#"mod foo { + pub mod bar { + pub struct SomeTypeInBar {} + } +} + +fn foo(x: >| NodeFinder<'a> { let struct_id = get_type_struct_id(typ); let is_primitive = typ.is_primitive(); + let has_self_param = matches!(function_kind, FunctionKind::SelfType(..)); for (name, methods) in methods_by_name { - for (func_id, method_type) in methods.iter() { - if function_kind == FunctionKind::Any { - if let Some(method_type) = method_type { - if
method_type.unify(typ).is_err() { - continue; - } - } - - if let Some(struct_id) = struct_id { - let modifiers = self.interner.function_modifiers(&func_id); - let visibility = modifiers.visibility; - if !struct_member_is_visible( - struct_id, - visibility, - self.module_id, - self.def_maps, - ) { - continue; - } - } + let Some(func_id) = + methods.find_matching_method(typ, has_self_param, self.interner).or_else(|| { + // Also try to find a method assuming typ is `&mut typ`: + // we want to suggest methods that take `&mut self` even though a variable might not + // be mutable, so a user can know they need to mark it as mutable. + let typ = Type::MutableReference(Box::new(typ.clone())); + methods.find_matching_method(&typ, has_self_param, self.interner) + }) + else { + continue; + }; - if is_primitive - && !method_call_is_visible( - typ, - func_id, - self.module_id, - self.interner, - self.def_maps, - ) - { + if let Some(struct_id) = struct_id { + let modifiers = self.interner.function_modifiers(&func_id); + let visibility = modifiers.visibility; + if !struct_member_is_visible(struct_id, visibility, self.module_id, self.def_maps) { continue; } + } - if name_matches(name, prefix) { - let completion_items = self.function_completion_items( - name, - func_id, - function_completion_kind, - function_kind, - None, // attribute first type - self_prefix, - ); - if !completion_items.is_empty() { - self.completion_items.extend(completion_items); - self.suggested_module_def_ids.insert(ModuleDefId::FunctionId(func_id)); - } + if is_primitive + && !method_call_is_visible( + typ, + func_id, + self.module_id, + self.interner, + self.def_maps, + ) + { + continue; + } + + if name_matches(name, prefix) { + let completion_items = self.function_completion_items( + name, + func_id, + function_completion_kind, + function_kind, + None, // attribute first type + self_prefix, + ); + if !completion_items.is_empty() { + self.completion_items.extend(completion_items); + self.suggested_module_def_ids.insert(ModuleDefId::FunctionId(func_id)); } } } diff --git a/tooling/lsp/src/requests/completion/tests.rs b/tooling/lsp/src/requests/completion/tests.rs index 8cfb2a4b5ee..9306e38a48a 100644 --- a/tooling/lsp/src/requests/completion/tests.rs +++ b/tooling/lsp/src/requests/completion/tests.rs @@ -2780,4 +2780,37 @@ fn main() { ) .await; } + + #[test] + async fn test_suggests_methods_based_on_type_generics() { + let src = r#" + struct Foo<T> { + t: T, + } + + impl<T> Foo<T> { + fn bar_baz(_self: Self) -> Field { + 5 + } + } + + impl Foo<i32> { + fn bar(_self: Self) -> Field { + 5 + } + + fn baz(_self: Self) -> Field { + 6 + } + } + + fn main() -> pub Field { + let foo: Foo<Field> = Foo { t: 5 }; + foo.b>|< + } + "#; + let items = get_completions(src).await; + assert_eq!(items.len(), 1); + assert!(items[0].label == "bar_baz()"); + } } diff --git a/tooling/lsp/src/requests/test_run.rs b/tooling/lsp/src/requests/test_run.rs index 50c699bb6a6..937fdcc0a5e 100644 --- a/tooling/lsp/src/requests/test_run.rs +++ b/tooling/lsp/src/requests/test_run.rs @@ -101,6 +101,11 @@ fn on_test_run_request_inner( result: "fail".to_string(), message: Some(message), }, + TestStatus::Skipped => NargoTestRunResult { + id: params.id.clone(), + result: "skipped".to_string(), + message: None, + }, TestStatus::CompileError(diag) => NargoTestRunResult { id: params.id.clone(), result: "error".to_string(), diff --git a/tooling/nargo/src/foreign_calls/mocker.rs b/tooling/nargo/src/foreign_calls/mocker.rs new file mode 100644 index 00000000000..c93d16bbaf6 --- /dev/null +++
b/tooling/nargo/src/foreign_calls/mocker.rs @@ -0,0 +1,176 @@ +use acvm::{ + acir::brillig::{ForeignCallParam, ForeignCallResult}, + pwg::ForeignCallWaitInfo, + AcirField, +}; +use noirc_printable_type::{decode_string_value, ForeignCallError}; +use serde::{Deserialize, Serialize}; + +use super::{ForeignCall, ForeignCallExecutor}; + +/// This struct represents an oracle mock. It can be used for testing programs that use oracles. +#[derive(Debug, PartialEq, Eq, Clone)] +struct MockedCall<F> { + /// The id of the mock, used to update or remove it + id: usize, + /// The oracle it's mocking + name: String, + /// Optionally match the parameters + params: Option<Vec<ForeignCallParam<F>>>, + /// The parameters with which the mock was last called + last_called_params: Option<Vec<ForeignCallParam<F>>>, + /// The result to return when this mock is called + result: ForeignCallResult<F>, + /// How many times should this mock be called before it is removed + times_left: Option<u64>, +} + +impl<F> MockedCall<F> { + fn new(id: usize, name: String) -> Self { + Self { + id, + name, + params: None, + last_called_params: None, + result: ForeignCallResult { values: vec![] }, + times_left: None, + } + } +} + +impl<F: PartialEq> MockedCall<F> { + fn matches(&self, name: &str, params: &[ForeignCallParam<F>]) -> bool { + self.name == name && (self.params.is_none() || self.params.as_deref() == Some(params)) + } +} + +#[derive(Debug, Default)] +pub(crate) struct MockForeignCallExecutor<F> { + /// Mocks have unique ids used to identify them in Noir, allowing to update or remove them. + last_mock_id: usize, + /// The registered mocks + mocked_responses: Vec<MockedCall<F>>, +} + +impl<F: AcirField> MockForeignCallExecutor<F> { + fn extract_mock_id( + foreign_call_inputs: &[ForeignCallParam<F>], + ) -> Result<(usize, &[ForeignCallParam<F>]), ForeignCallError> { + let (id, params) = + foreign_call_inputs.split_first().ok_or(ForeignCallError::MissingForeignCallInputs)?; + let id = + usize::try_from(id.unwrap_field().try_to_u64().expect("value does not fit into u64")) + .expect("value does not fit into usize"); + Ok((id, params)) + } + + fn find_mock_by_id(&self, id: usize) -> Option<&MockedCall<F>> { + self.mocked_responses.iter().find(|response| response.id == id) + } + + fn find_mock_by_id_mut(&mut self, id: usize) -> Option<&mut MockedCall<F>> { + self.mocked_responses.iter_mut().find(|response| response.id == id) + } + + fn parse_string(param: &ForeignCallParam<F>) -> String { + let fields: Vec<_> = param.fields().to_vec(); + decode_string_value(&fields) + } +} + +impl<F: AcirField + Serialize + for<'a> Deserialize<'a>> ForeignCallExecutor<F> + for MockForeignCallExecutor<F> +{ + fn execute( + &mut self, + foreign_call: &ForeignCallWaitInfo<F>, + ) -> Result<ForeignCallResult<F>, ForeignCallError> { + let foreign_call_name = foreign_call.function.as_str(); + match ForeignCall::lookup(foreign_call_name) { + Some(ForeignCall::CreateMock) => { + let mock_oracle_name = Self::parse_string(&foreign_call.inputs[0]); + assert!(ForeignCall::lookup(&mock_oracle_name).is_none()); + let id = self.last_mock_id; + self.mocked_responses.push(MockedCall::new(id, mock_oracle_name)); + self.last_mock_id += 1; + + Ok(F::from(id).into()) + } + Some(ForeignCall::SetMockParams) => { + let (id, params) = Self::extract_mock_id(&foreign_call.inputs)?; + self.find_mock_by_id_mut(id) + .unwrap_or_else(|| panic!("Unknown mock id {}", id)) + .params = Some(params.to_vec()); + + Ok(ForeignCallResult::default()) + } + Some(ForeignCall::GetMockLastParams) => { + let (id, _) = Self::extract_mock_id(&foreign_call.inputs)?; + let mock = + self.find_mock_by_id(id).unwrap_or_else(|| panic!("Unknown mock id {}", id)); + + let last_called_params =
mock + .last_called_params + .clone() + .unwrap_or_else(|| panic!("Mock {} was never called", mock.name)); + + Ok(last_called_params.into()) + } + Some(ForeignCall::SetMockReturns) => { + let (id, params) = Self::extract_mock_id(&foreign_call.inputs)?; + self.find_mock_by_id_mut(id) + .unwrap_or_else(|| panic!("Unknown mock id {}", id)) + .result = ForeignCallResult { values: params.to_vec() }; + + Ok(ForeignCallResult::default()) + } + Some(ForeignCall::SetMockTimes) => { + let (id, params) = Self::extract_mock_id(&foreign_call.inputs)?; + let times = + params[0].unwrap_field().try_to_u64().expect("Invalid bit size of times"); + + self.find_mock_by_id_mut(id) + .unwrap_or_else(|| panic!("Unknown mock id {}", id)) + .times_left = Some(times); + + Ok(ForeignCallResult::default()) + } + Some(ForeignCall::ClearMock) => { + let (id, _) = Self::extract_mock_id(&foreign_call.inputs)?; + self.mocked_responses.retain(|response| response.id != id); + Ok(ForeignCallResult::default()) + } + _ => { + let mock_response_position = self + .mocked_responses + .iter() + .position(|response| response.matches(foreign_call_name, &foreign_call.inputs)); + + if let Some(response_position) = mock_response_position { + // If the program has registered a mocked response to this oracle call then we prefer responding + // with that. + + let mock = self + .mocked_responses + .get_mut(response_position) + .expect("Invalid position of mocked response"); + + mock.last_called_params = Some(foreign_call.inputs.clone()); + + let result = mock.result.values.clone(); + + if let Some(times_left) = &mut mock.times_left { + *times_left -= 1; + if *times_left == 0 { + self.mocked_responses.remove(response_position); + } + } + + Ok(result.into()) + } else { + Err(ForeignCallError::NoHandler(foreign_call_name.to_string())) + } + } + } + } +} diff --git a/tooling/nargo/src/foreign_calls/mod.rs b/tooling/nargo/src/foreign_calls/mod.rs new file mode 100644 index 00000000000..16ed71e11e3 --- /dev/null +++ b/tooling/nargo/src/foreign_calls/mod.rs @@ -0,0 +1,146 @@ +use std::path::PathBuf; + +use acvm::{acir::brillig::ForeignCallResult, pwg::ForeignCallWaitInfo, AcirField}; +use mocker::MockForeignCallExecutor; +use noirc_printable_type::ForeignCallError; +use print::PrintForeignCallExecutor; +use rand::Rng; +use rpc::RPCForeignCallExecutor; +use serde::{Deserialize, Serialize}; + +pub(crate) mod mocker; +pub(crate) mod print; +pub(crate) mod rpc; + +pub trait ForeignCallExecutor { + fn execute( + &mut self, + foreign_call: &ForeignCallWaitInfo, + ) -> Result, ForeignCallError>; +} + +/// This enumeration represents the Brillig foreign calls that are natively supported by nargo. 
+/// After resolution of a foreign call, nargo will restart execution of the ACVM
+pub enum ForeignCall {
+    Print,
+    CreateMock,
+    SetMockParams,
+    GetMockLastParams,
+    SetMockReturns,
+    SetMockTimes,
+    ClearMock,
+}
+
+impl std::fmt::Display for ForeignCall {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.name())
+    }
+}
+
+impl ForeignCall {
+    pub(crate) fn name(&self) -> &'static str {
+        match self {
+            ForeignCall::Print => "print",
+            ForeignCall::CreateMock => "create_mock",
+            ForeignCall::SetMockParams => "set_mock_params",
+            ForeignCall::GetMockLastParams => "get_mock_last_params",
+            ForeignCall::SetMockReturns => "set_mock_returns",
+            ForeignCall::SetMockTimes => "set_mock_times",
+            ForeignCall::ClearMock => "clear_mock",
+        }
+    }
+
+    pub(crate) fn lookup(op_name: &str) -> Option<ForeignCall> {
+        match op_name {
+            "print" => Some(ForeignCall::Print),
+            "create_mock" => Some(ForeignCall::CreateMock),
+            "set_mock_params" => Some(ForeignCall::SetMockParams),
+            "get_mock_last_params" => Some(ForeignCall::GetMockLastParams),
+            "set_mock_returns" => Some(ForeignCall::SetMockReturns),
+            "set_mock_times" => Some(ForeignCall::SetMockTimes),
+            "clear_mock" => Some(ForeignCall::ClearMock),
+            _ => None,
+        }
+    }
+}
+
+#[derive(Debug, Default)]
+pub struct DefaultForeignCallExecutor<F> {
+    /// The executor for any [`ForeignCall::Print`] calls.
+    printer: Option<PrintForeignCallExecutor>,
+    mocker: MockForeignCallExecutor<F>,
+    external: Option<RPCForeignCallExecutor>,
+}
+
+impl<F: Default> DefaultForeignCallExecutor<F> {
+    pub fn new(
+        show_output: bool,
+        resolver_url: Option<&str>,
+        root_path: Option<PathBuf>,
+        package_name: Option<String>,
+    ) -> Self {
+        let id = rand::thread_rng().gen();
+        let printer = if show_output { Some(PrintForeignCallExecutor) } else { None };
+        let external_resolver = resolver_url.map(|resolver_url| {
+            RPCForeignCallExecutor::new(resolver_url, id, root_path, package_name)
+        });
+        DefaultForeignCallExecutor {
+            printer,
+            mocker: MockForeignCallExecutor::default(),
+            external: external_resolver,
+        }
+    }
+}
+
+impl<F: AcirField + Serialize + for<'a> Deserialize<'a>> ForeignCallExecutor<F>
+    for DefaultForeignCallExecutor<F>
+{
+    fn execute(
+        &mut self,
+        foreign_call: &ForeignCallWaitInfo<F>,
+    ) -> Result<ForeignCallResult<F>, ForeignCallError> {
+        let foreign_call_name = foreign_call.function.as_str();
+        match ForeignCall::lookup(foreign_call_name) {
+            Some(ForeignCall::Print) => {
+                if let Some(printer) = &mut self.printer {
+                    printer.execute(foreign_call)
+                } else {
+                    Ok(ForeignCallResult::default())
+                }
+            }
+            Some(
+                ForeignCall::CreateMock
+                | ForeignCall::SetMockParams
+                | ForeignCall::GetMockLastParams
+                | ForeignCall::SetMockReturns
+                | ForeignCall::SetMockTimes
+                | ForeignCall::ClearMock,
+            ) => self.mocker.execute(foreign_call),
+
+            None => {
+                // First check if there's any defined mock responses for this foreign call.
+                match self.mocker.execute(foreign_call) {
+                    Err(ForeignCallError::NoHandler(_)) => (),
+                    response_or_error => return response_or_error,
+                };
+
+                if let Some(external_resolver) = &mut self.external {
+                    // If the user has registered an external resolver then we forward any remaining oracle calls there.
+                    match external_resolver.execute(foreign_call) {
+                        Err(ForeignCallError::NoHandler(_)) => (),
+                        response_or_error => return response_or_error,
+                    };
+                }
+
+                // If all executors have no handler for the given foreign call then we cannot
+                // return a correct response to the ACVM.
The best we can do is to return an empty response, + // this allows us to ignore any foreign calls which exist solely to pass information from inside + // the circuit to the environment (e.g. custom logging) as the execution will still be able to progress. + // + // We optimistically return an empty response for all oracle calls as the ACVM will error + // should a response have been required. + Ok(ForeignCallResult::default()) + } + } + } +} diff --git a/tooling/nargo/src/foreign_calls/print.rs b/tooling/nargo/src/foreign_calls/print.rs new file mode 100644 index 00000000000..92fcd65ae28 --- /dev/null +++ b/tooling/nargo/src/foreign_calls/print.rs @@ -0,0 +1,36 @@ +use acvm::{acir::brillig::ForeignCallResult, pwg::ForeignCallWaitInfo, AcirField}; +use noirc_printable_type::{ForeignCallError, PrintableValueDisplay}; + +use super::{ForeignCall, ForeignCallExecutor}; + +#[derive(Debug, Default)] +pub(crate) struct PrintForeignCallExecutor; + +impl ForeignCallExecutor for PrintForeignCallExecutor { + fn execute( + &mut self, + foreign_call: &ForeignCallWaitInfo, + ) -> Result, ForeignCallError> { + let foreign_call_name = foreign_call.function.as_str(); + match ForeignCall::lookup(foreign_call_name) { + Some(ForeignCall::Print) => { + let skip_newline = foreign_call.inputs[0].unwrap_field().is_zero(); + + let foreign_call_inputs = foreign_call + .inputs + .split_first() + .ok_or(ForeignCallError::MissingForeignCallInputs)? + .1; + + let display_values: PrintableValueDisplay = foreign_call_inputs.try_into()?; + let display_string = + format!("{display_values}{}", if skip_newline { "" } else { "\n" }); + + print!("{display_string}"); + + Ok(ForeignCallResult::default()) + } + _ => Err(ForeignCallError::NoHandler(foreign_call_name.to_string())), + } + } +} diff --git a/tooling/nargo/src/foreign_calls/rpc.rs b/tooling/nargo/src/foreign_calls/rpc.rs new file mode 100644 index 00000000000..0653eb1c7e3 --- /dev/null +++ b/tooling/nargo/src/foreign_calls/rpc.rs @@ -0,0 +1,227 @@ +use std::path::PathBuf; + +use acvm::{acir::brillig::ForeignCallResult, pwg::ForeignCallWaitInfo, AcirField}; +use jsonrpc::{arg as build_json_rpc_arg, minreq_http::Builder, Client}; +use noirc_printable_type::ForeignCallError; +use serde::{Deserialize, Serialize}; + +use super::ForeignCallExecutor; + +#[derive(Debug)] +pub(crate) struct RPCForeignCallExecutor { + /// A randomly generated id for this `DefaultForeignCallExecutor`. + /// + /// This is used so that a single `external_resolver` can distinguish between requests from multiple + /// instantiations of `DefaultForeignCallExecutor`. + id: u64, + /// JSON RPC client to resolve foreign calls + external_resolver: Client, + /// Root path to the program or workspace in execution. + root_path: Option, + /// Name of the package in execution + package_name: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ResolveForeignCallRequest { + /// A session ID which allows the external RPC server to link this foreign call request to other foreign calls + /// for the same program execution. + /// + /// This is intended to allow a single RPC server to maintain state related to multiple program executions being + /// performed in parallel. + session_id: u64, + + #[serde(flatten)] + /// The foreign call which the external RPC server is to provide a response for. + function_call: ForeignCallWaitInfo, + + #[serde(skip_serializing_if = "Option::is_none")] + /// Root path to the program or workspace in execution. 
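+    /// (Illustrative, not part of this diff: because `function_call` is flattened by serde,
+    /// a request serializes roughly as
+    /// `{"session_id": 1, "function": "echo", "inputs": [...], "root_path": "/path/to/pkg"}`,
+    /// with the optional fields omitted when absent.)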
+ root_path: Option, + #[serde(skip_serializing_if = "Option::is_none")] + /// Name of the package in execution + package_name: Option, +} + +impl RPCForeignCallExecutor { + pub(crate) fn new( + resolver_url: &str, + id: u64, + root_path: Option, + package_name: Option, + ) -> Self { + let mut transport_builder = + Builder::new().url(resolver_url).expect("Invalid oracle resolver URL"); + + if let Some(Ok(timeout)) = + std::env::var("NARGO_FOREIGN_CALL_TIMEOUT").ok().map(|timeout| timeout.parse()) + { + let timeout_duration = std::time::Duration::from_millis(timeout); + transport_builder = transport_builder.timeout(timeout_duration); + }; + let oracle_resolver = Client::with_transport(transport_builder.build()); + + RPCForeignCallExecutor { external_resolver: oracle_resolver, id, root_path, package_name } + } +} + +impl Deserialize<'a>> ForeignCallExecutor + for RPCForeignCallExecutor +{ + fn execute( + &mut self, + foreign_call: &ForeignCallWaitInfo, + ) -> Result, ForeignCallError> { + let encoded_params = vec![build_json_rpc_arg(ResolveForeignCallRequest { + session_id: self.id, + function_call: foreign_call.clone(), + root_path: self.root_path.clone().map(|path| path.to_str().unwrap().to_string()), + package_name: self.package_name.clone(), + })]; + + let req = self.external_resolver.build_request("resolve_foreign_call", &encoded_params); + + let response = self.external_resolver.send_request(req)?; + + let parsed_response: ForeignCallResult = response.result()?; + + Ok(parsed_response) + } +} + +#[cfg(test)] +mod tests { + use acvm::{ + acir::brillig::ForeignCallParam, brillig_vm::brillig::ForeignCallResult, + pwg::ForeignCallWaitInfo, FieldElement, + }; + use jsonrpc_core::Result as RpcResult; + use jsonrpc_derive::rpc; + use jsonrpc_http_server::{Server, ServerBuilder}; + + use super::{ForeignCallExecutor, RPCForeignCallExecutor, ResolveForeignCallRequest}; + + #[allow(unreachable_pub)] + #[rpc] + pub trait OracleResolver { + #[rpc(name = "resolve_foreign_call")] + fn resolve_foreign_call( + &self, + req: ResolveForeignCallRequest, + ) -> RpcResult>; + } + + struct OracleResolverImpl; + + impl OracleResolverImpl { + fn echo(&self, param: ForeignCallParam) -> ForeignCallResult { + vec![param].into() + } + + fn sum(&self, array: ForeignCallParam) -> ForeignCallResult { + let mut res: FieldElement = 0_usize.into(); + + for value in array.fields() { + res += value; + } + + res.into() + } + } + + impl OracleResolver for OracleResolverImpl { + fn resolve_foreign_call( + &self, + req: ResolveForeignCallRequest, + ) -> RpcResult> { + let response = match req.function_call.function.as_str() { + "sum" => self.sum(req.function_call.inputs[0].clone()), + "echo" => self.echo(req.function_call.inputs[0].clone()), + "id" => FieldElement::from(req.session_id as u128).into(), + + _ => panic!("unexpected foreign call"), + }; + Ok(response) + } + } + + fn build_oracle_server() -> (Server, String) { + let mut io = jsonrpc_core::IoHandler::new(); + io.extend_with(OracleResolverImpl.to_delegate()); + + // Choosing port 0 results in a random port being assigned. 
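+        // (Illustrative) outside of tests an executor like this is built when the user passes
+        // `--oracle-resolver <url>` to nargo; the client timeout above can be raised with e.g.
+        //   NARGO_FOREIGN_CALL_TIMEOUT=30000 nargo test --oracle-resolver http://127.0.0.1:5555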
+ let server = ServerBuilder::new(io) + .start_http(&"127.0.0.1:0".parse().expect("Invalid address")) + .expect("Could not start server"); + + let url = format!("http://{}", server.address()); + (server, url) + } + + #[test] + fn test_oracle_resolver_echo() { + let (server, url) = build_oracle_server(); + + let mut executor = RPCForeignCallExecutor::new(&url, 1, None, None); + + let foreign_call: ForeignCallWaitInfo = ForeignCallWaitInfo { + function: "echo".to_string(), + inputs: vec![ForeignCallParam::Single(1_u128.into())], + }; + + let result = executor.execute(&foreign_call); + assert_eq!(result.unwrap(), ForeignCallResult { values: foreign_call.inputs }); + + server.close(); + } + + #[test] + fn test_oracle_resolver_sum() { + let (server, url) = build_oracle_server(); + + let mut executor = RPCForeignCallExecutor::new(&url, 2, None, None); + + let foreign_call: ForeignCallWaitInfo = ForeignCallWaitInfo { + function: "sum".to_string(), + inputs: vec![ForeignCallParam::Array(vec![1_usize.into(), 2_usize.into()])], + }; + + let result = executor.execute(&foreign_call); + assert_eq!(result.unwrap(), FieldElement::from(3_usize).into()); + + server.close(); + } + + #[test] + fn foreign_call_executor_id_is_persistent() { + let (server, url) = build_oracle_server(); + + let mut executor = RPCForeignCallExecutor::new(&url, 3, None, None); + + let foreign_call: ForeignCallWaitInfo = + ForeignCallWaitInfo { function: "id".to_string(), inputs: Vec::new() }; + + let result_1 = executor.execute(&foreign_call).unwrap(); + let result_2 = executor.execute(&foreign_call).unwrap(); + assert_eq!(result_1, result_2); + + server.close(); + } + + #[test] + fn oracle_resolver_rpc_can_distinguish_executors() { + let (server, url) = build_oracle_server(); + + let mut executor_1 = RPCForeignCallExecutor::new(&url, 4, None, None); + let mut executor_2 = RPCForeignCallExecutor::new(&url, 5, None, None); + + let foreign_call: ForeignCallWaitInfo = + ForeignCallWaitInfo { function: "id".to_string(), inputs: Vec::new() }; + + let result_1 = executor_1.execute(&foreign_call).unwrap(); + let result_2 = executor_2.execute(&foreign_call).unwrap(); + assert_ne!(result_1, result_2); + + server.close(); + } +} diff --git a/tooling/nargo/src/lib.rs b/tooling/nargo/src/lib.rs index 88f07e0c292..74b7f54d860 100644 --- a/tooling/nargo/src/lib.rs +++ b/tooling/nargo/src/lib.rs @@ -9,6 +9,7 @@ pub mod constants; pub mod errors; +pub mod foreign_calls; pub mod ops; pub mod package; pub mod workspace; diff --git a/tooling/nargo/src/ops/check.rs b/tooling/nargo/src/ops/check.rs index 14d629ab0f6..707353ccdad 100644 --- a/tooling/nargo/src/ops/check.rs +++ b/tooling/nargo/src/ops/check.rs @@ -2,8 +2,8 @@ use acvm::compiler::CircuitSimulator; use noirc_driver::{CompiledProgram, ErrorsAndWarnings}; use noirc_errors::{CustomDiagnostic, FileDiagnostic}; +/// Run each function through a circuit simulator to check that they are solvable. 
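+/// (Hedged note: the simulator only checks that every witness can be solved for from the
+/// circuit's inputs; it does not execute the circuit against concrete input values.)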
pub fn check_program(compiled_program: &CompiledProgram) -> Result<(), ErrorsAndWarnings> { - // Check if the program is solvable for (i, circuit) in compiled_program.program.functions.iter().enumerate() { let mut simulator = CircuitSimulator::default(); if !simulator.check_circuit(circuit) { diff --git a/tooling/nargo/src/ops/execute.rs b/tooling/nargo/src/ops/execute.rs index 09ef554d2aa..57116ec2efd 100644 --- a/tooling/nargo/src/ops/execute.rs +++ b/tooling/nargo/src/ops/execute.rs @@ -10,10 +10,9 @@ use acvm::{acir::circuit::Circuit, acir::native_types::WitnessMap}; use acvm::{AcirField, BlackBoxFunctionSolver}; use crate::errors::ExecutionError; +use crate::foreign_calls::ForeignCallExecutor; use crate::NargoError; -use super::foreign_calls::ForeignCallExecutor; - struct ProgramExecutor<'a, F, B: BlackBoxFunctionSolver, E: ForeignCallExecutor> { functions: &'a [Circuit], diff --git a/tooling/nargo/src/ops/foreign_calls.rs b/tooling/nargo/src/ops/foreign_calls.rs deleted file mode 100644 index 30785949a46..00000000000 --- a/tooling/nargo/src/ops/foreign_calls.rs +++ /dev/null @@ -1,494 +0,0 @@ -use std::path::PathBuf; - -use acvm::{ - acir::brillig::{ForeignCallParam, ForeignCallResult}, - pwg::ForeignCallWaitInfo, - AcirField, -}; -use jsonrpc::{arg as build_json_rpc_arg, minreq_http::Builder, Client}; -use noirc_printable_type::{decode_string_value, ForeignCallError, PrintableValueDisplay}; -use rand::Rng; -use serde::{Deserialize, Serialize}; - -pub trait ForeignCallExecutor { - fn execute( - &mut self, - foreign_call: &ForeignCallWaitInfo, - ) -> Result, ForeignCallError>; -} - -/// This enumeration represents the Brillig foreign calls that are natively supported by nargo. -/// After resolution of a foreign call, nargo will restart execution of the ACVM -pub enum ForeignCall { - Print, - CreateMock, - SetMockParams, - GetMockLastParams, - SetMockReturns, - SetMockTimes, - ClearMock, -} - -impl std::fmt::Display for ForeignCall { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.name()) - } -} - -impl ForeignCall { - pub(crate) fn name(&self) -> &'static str { - match self { - ForeignCall::Print => "print", - ForeignCall::CreateMock => "create_mock", - ForeignCall::SetMockParams => "set_mock_params", - ForeignCall::GetMockLastParams => "get_mock_last_params", - ForeignCall::SetMockReturns => "set_mock_returns", - ForeignCall::SetMockTimes => "set_mock_times", - ForeignCall::ClearMock => "clear_mock", - } - } - - pub(crate) fn lookup(op_name: &str) -> Option { - match op_name { - "print" => Some(ForeignCall::Print), - "create_mock" => Some(ForeignCall::CreateMock), - "set_mock_params" => Some(ForeignCall::SetMockParams), - "get_mock_last_params" => Some(ForeignCall::GetMockLastParams), - "set_mock_returns" => Some(ForeignCall::SetMockReturns), - "set_mock_times" => Some(ForeignCall::SetMockTimes), - "clear_mock" => Some(ForeignCall::ClearMock), - _ => None, - } - } -} - -/// This struct represents an oracle mock. It can be used for testing programs that use oracles. 
-#[derive(Debug, PartialEq, Eq, Clone)] -struct MockedCall { - /// The id of the mock, used to update or remove it - id: usize, - /// The oracle it's mocking - name: String, - /// Optionally match the parameters - params: Option>>, - /// The parameters with which the mock was last called - last_called_params: Option>>, - /// The result to return when this mock is called - result: ForeignCallResult, - /// How many times should this mock be called before it is removed - times_left: Option, -} - -impl MockedCall { - fn new(id: usize, name: String) -> Self { - Self { - id, - name, - params: None, - last_called_params: None, - result: ForeignCallResult { values: vec![] }, - times_left: None, - } - } -} - -impl MockedCall { - fn matches(&self, name: &str, params: &[ForeignCallParam]) -> bool { - self.name == name && (self.params.is_none() || self.params.as_deref() == Some(params)) - } -} - -#[derive(Debug, Default)] -pub struct DefaultForeignCallExecutor { - /// A randomly generated id for this `DefaultForeignCallExecutor`. - /// - /// This is used so that a single `external_resolver` can distinguish between requests from multiple - /// instantiations of `DefaultForeignCallExecutor`. - id: u64, - - /// Mocks have unique ids used to identify them in Noir, allowing to update or remove them. - last_mock_id: usize, - /// The registered mocks - mocked_responses: Vec>, - /// Whether to print [`ForeignCall::Print`] output. - show_output: bool, - /// JSON RPC client to resolve foreign calls - external_resolver: Option, - /// Root path to the program or workspace in execution. - root_path: Option, - /// Name of the package in execution - package_name: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -struct ResolveForeignCallRequest { - /// A session ID which allows the external RPC server to link this foreign call request to other foreign calls - /// for the same program execution. - /// - /// This is intended to allow a single RPC server to maintain state related to multiple program executions being - /// performed in parallel. - session_id: u64, - - #[serde(flatten)] - /// The foreign call which the external RPC server is to provide a response for. - function_call: ForeignCallWaitInfo, - - #[serde(skip_serializing_if = "Option::is_none")] - /// Root path to the program or workspace in execution. 
- root_path: Option, - #[serde(skip_serializing_if = "Option::is_none")] - /// Name of the package in execution - package_name: Option, -} - -impl DefaultForeignCallExecutor { - pub fn new( - show_output: bool, - resolver_url: Option<&str>, - root_path: Option, - package_name: Option, - ) -> Self { - let oracle_resolver = resolver_url.map(|resolver_url| { - let mut transport_builder = - Builder::new().url(resolver_url).expect("Invalid oracle resolver URL"); - - if let Some(Ok(timeout)) = - std::env::var("NARGO_FOREIGN_CALL_TIMEOUT").ok().map(|timeout| timeout.parse()) - { - let timeout_duration = std::time::Duration::from_millis(timeout); - transport_builder = transport_builder.timeout(timeout_duration); - }; - Client::with_transport(transport_builder.build()) - }); - DefaultForeignCallExecutor { - show_output, - external_resolver: oracle_resolver, - id: rand::thread_rng().gen(), - mocked_responses: Vec::new(), - last_mock_id: 0, - root_path, - package_name, - } - } -} - -impl DefaultForeignCallExecutor { - fn extract_mock_id( - foreign_call_inputs: &[ForeignCallParam], - ) -> Result<(usize, &[ForeignCallParam]), ForeignCallError> { - let (id, params) = - foreign_call_inputs.split_first().ok_or(ForeignCallError::MissingForeignCallInputs)?; - let id = - usize::try_from(id.unwrap_field().try_to_u64().expect("value does not fit into u64")) - .expect("value does not fit into usize"); - Ok((id, params)) - } - - fn find_mock_by_id(&self, id: usize) -> Option<&MockedCall> { - self.mocked_responses.iter().find(|response| response.id == id) - } - - fn find_mock_by_id_mut(&mut self, id: usize) -> Option<&mut MockedCall> { - self.mocked_responses.iter_mut().find(|response| response.id == id) - } - - fn parse_string(param: &ForeignCallParam) -> String { - let fields: Vec<_> = param.fields().to_vec(); - decode_string_value(&fields) - } - - fn execute_print(foreign_call_inputs: &[ForeignCallParam]) -> Result<(), ForeignCallError> { - let skip_newline = foreign_call_inputs[0].unwrap_field().is_zero(); - - let foreign_call_inputs = - foreign_call_inputs.split_first().ok_or(ForeignCallError::MissingForeignCallInputs)?.1; - let display_string = Self::format_printable_value(foreign_call_inputs, skip_newline)?; - - print!("{display_string}"); - - Ok(()) - } - - fn format_printable_value( - foreign_call_inputs: &[ForeignCallParam], - skip_newline: bool, - ) -> Result { - let display_values: PrintableValueDisplay = foreign_call_inputs.try_into()?; - - let result = format!("{display_values}{}", if skip_newline { "" } else { "\n" }); - - Ok(result) - } -} - -impl Deserialize<'a>> ForeignCallExecutor - for DefaultForeignCallExecutor -{ - fn execute( - &mut self, - foreign_call: &ForeignCallWaitInfo, - ) -> Result, ForeignCallError> { - let foreign_call_name = foreign_call.function.as_str(); - match ForeignCall::lookup(foreign_call_name) { - Some(ForeignCall::Print) => { - if self.show_output { - Self::execute_print(&foreign_call.inputs)?; - } - Ok(ForeignCallResult::default()) - } - Some(ForeignCall::CreateMock) => { - let mock_oracle_name = Self::parse_string(&foreign_call.inputs[0]); - assert!(ForeignCall::lookup(&mock_oracle_name).is_none()); - let id = self.last_mock_id; - self.mocked_responses.push(MockedCall::new(id, mock_oracle_name)); - self.last_mock_id += 1; - - Ok(F::from(id).into()) - } - Some(ForeignCall::SetMockParams) => { - let (id, params) = Self::extract_mock_id(&foreign_call.inputs)?; - self.find_mock_by_id_mut(id) - .unwrap_or_else(|| panic!("Unknown mock id {}", id)) - .params = 
Some(params.to_vec()); - - Ok(ForeignCallResult::default()) - } - Some(ForeignCall::GetMockLastParams) => { - let (id, _) = Self::extract_mock_id(&foreign_call.inputs)?; - let mock = - self.find_mock_by_id(id).unwrap_or_else(|| panic!("Unknown mock id {}", id)); - - let last_called_params = mock - .last_called_params - .clone() - .unwrap_or_else(|| panic!("Mock {} was never called", mock.name)); - - Ok(last_called_params.into()) - } - Some(ForeignCall::SetMockReturns) => { - let (id, params) = Self::extract_mock_id(&foreign_call.inputs)?; - self.find_mock_by_id_mut(id) - .unwrap_or_else(|| panic!("Unknown mock id {}", id)) - .result = ForeignCallResult { values: params.to_vec() }; - - Ok(ForeignCallResult::default()) - } - Some(ForeignCall::SetMockTimes) => { - let (id, params) = Self::extract_mock_id(&foreign_call.inputs)?; - let times = - params[0].unwrap_field().try_to_u64().expect("Invalid bit size of times"); - - self.find_mock_by_id_mut(id) - .unwrap_or_else(|| panic!("Unknown mock id {}", id)) - .times_left = Some(times); - - Ok(ForeignCallResult::default()) - } - Some(ForeignCall::ClearMock) => { - let (id, _) = Self::extract_mock_id(&foreign_call.inputs)?; - self.mocked_responses.retain(|response| response.id != id); - Ok(ForeignCallResult::default()) - } - None => { - let mock_response_position = self - .mocked_responses - .iter() - .position(|response| response.matches(foreign_call_name, &foreign_call.inputs)); - - if let Some(response_position) = mock_response_position { - // If the program has registered a mocked response to this oracle call then we prefer responding - // with that. - - let mock = self - .mocked_responses - .get_mut(response_position) - .expect("Invalid position of mocked response"); - - mock.last_called_params = Some(foreign_call.inputs.clone()); - - let result = mock.result.values.clone(); - - if let Some(times_left) = &mut mock.times_left { - *times_left -= 1; - if *times_left == 0 { - self.mocked_responses.remove(response_position); - } - } - - Ok(result.into()) - } else if let Some(external_resolver) = &self.external_resolver { - // If the user has registered an external resolver then we forward any remaining oracle calls there. - - let encoded_params = vec![build_json_rpc_arg(ResolveForeignCallRequest { - session_id: self.id, - function_call: foreign_call.clone(), - root_path: self - .root_path - .clone() - .map(|path| path.to_str().unwrap().to_string()), - package_name: self.package_name.clone(), - })]; - - let req = - external_resolver.build_request("resolve_foreign_call", &encoded_params); - - let response = external_resolver.send_request(req)?; - - let parsed_response: ForeignCallResult = response.result()?; - - Ok(parsed_response) - } else { - // If there's no registered mock oracle response and no registered resolver then we cannot - // return a correct response to the ACVM. The best we can do is to return an empty response, - // this allows us to ignore any foreign calls which exist solely to pass information from inside - // the circuit to the environment (e.g. custom logging) as the execution will still be able to progress. - // - // We optimistically return an empty response for all oracle calls as the ACVM will error - // should a response have been required. 
- Ok(ForeignCallResult::default()) - } - } - } - } -} - -#[cfg(test)] -mod tests { - use acvm::{ - acir::brillig::ForeignCallParam, brillig_vm::brillig::ForeignCallResult, - pwg::ForeignCallWaitInfo, FieldElement, - }; - use jsonrpc_core::Result as RpcResult; - use jsonrpc_derive::rpc; - use jsonrpc_http_server::{Server, ServerBuilder}; - - use crate::ops::{DefaultForeignCallExecutor, ForeignCallExecutor}; - - use super::ResolveForeignCallRequest; - - #[allow(unreachable_pub)] - #[rpc] - pub trait OracleResolver { - #[rpc(name = "resolve_foreign_call")] - fn resolve_foreign_call( - &self, - req: ResolveForeignCallRequest, - ) -> RpcResult>; - } - - struct OracleResolverImpl; - - impl OracleResolverImpl { - fn echo(&self, param: ForeignCallParam) -> ForeignCallResult { - vec![param].into() - } - - fn sum(&self, array: ForeignCallParam) -> ForeignCallResult { - let mut res: FieldElement = 0_usize.into(); - - for value in array.fields() { - res += value; - } - - res.into() - } - } - - impl OracleResolver for OracleResolverImpl { - fn resolve_foreign_call( - &self, - req: ResolveForeignCallRequest, - ) -> RpcResult> { - let response = match req.function_call.function.as_str() { - "sum" => self.sum(req.function_call.inputs[0].clone()), - "echo" => self.echo(req.function_call.inputs[0].clone()), - "id" => FieldElement::from(req.session_id as u128).into(), - - _ => panic!("unexpected foreign call"), - }; - Ok(response) - } - } - - fn build_oracle_server() -> (Server, String) { - let mut io = jsonrpc_core::IoHandler::new(); - io.extend_with(OracleResolverImpl.to_delegate()); - - // Choosing port 0 results in a random port being assigned. - let server = ServerBuilder::new(io) - .start_http(&"127.0.0.1:0".parse().expect("Invalid address")) - .expect("Could not start server"); - - let url = format!("http://{}", server.address()); - (server, url) - } - - #[test] - fn test_oracle_resolver_echo() { - let (server, url) = build_oracle_server(); - - let mut executor = - DefaultForeignCallExecutor::::new(false, Some(&url), None, None); - - let foreign_call = ForeignCallWaitInfo { - function: "echo".to_string(), - inputs: vec![ForeignCallParam::Single(1_u128.into())], - }; - - let result = executor.execute(&foreign_call); - assert_eq!(result.unwrap(), ForeignCallResult { values: foreign_call.inputs }); - - server.close(); - } - - #[test] - fn test_oracle_resolver_sum() { - let (server, url) = build_oracle_server(); - - let mut executor = DefaultForeignCallExecutor::new(false, Some(&url), None, None); - - let foreign_call = ForeignCallWaitInfo { - function: "sum".to_string(), - inputs: vec![ForeignCallParam::Array(vec![1_usize.into(), 2_usize.into()])], - }; - - let result = executor.execute(&foreign_call); - assert_eq!(result.unwrap(), FieldElement::from(3_usize).into()); - - server.close(); - } - - #[test] - fn foreign_call_executor_id_is_persistent() { - let (server, url) = build_oracle_server(); - - let mut executor = - DefaultForeignCallExecutor::::new(false, Some(&url), None, None); - - let foreign_call = ForeignCallWaitInfo { function: "id".to_string(), inputs: Vec::new() }; - - let result_1 = executor.execute(&foreign_call).unwrap(); - let result_2 = executor.execute(&foreign_call).unwrap(); - assert_eq!(result_1, result_2); - - server.close(); - } - - #[test] - fn oracle_resolver_rpc_can_distinguish_executors() { - let (server, url) = build_oracle_server(); - - let mut executor_1 = - DefaultForeignCallExecutor::::new(false, Some(&url), None, None); - let mut executor_2 = - 
DefaultForeignCallExecutor::::new(false, Some(&url), None, None); - - let foreign_call = ForeignCallWaitInfo { function: "id".to_string(), inputs: Vec::new() }; - - let result_1 = executor_1.execute(&foreign_call).unwrap(); - let result_2 = executor_2.execute(&foreign_call).unwrap(); - assert_ne!(result_1, result_2); - - server.close(); - } -} diff --git a/tooling/nargo/src/ops/mod.rs b/tooling/nargo/src/ops/mod.rs index f70577a14f1..04efeb5a9ec 100644 --- a/tooling/nargo/src/ops/mod.rs +++ b/tooling/nargo/src/ops/mod.rs @@ -4,7 +4,6 @@ pub use self::compile::{ compile_workspace, report_errors, }; pub use self::execute::{execute_program, execute_program_with_profiling}; -pub use self::foreign_calls::{DefaultForeignCallExecutor, ForeignCall, ForeignCallExecutor}; pub use self::optimize::{optimize_contract, optimize_program}; pub use self::transform::{transform_contract, transform_program}; @@ -13,7 +12,6 @@ pub use self::test::{run_test, TestStatus}; mod check; mod compile; mod execute; -mod foreign_calls; mod optimize; mod test; mod transform; diff --git a/tooling/nargo/src/ops/test.rs b/tooling/nargo/src/ops/test.rs index 370a4235f61..e258627b522 100644 --- a/tooling/nargo/src/ops/test.rs +++ b/tooling/nargo/src/ops/test.rs @@ -1,27 +1,42 @@ use std::path::PathBuf; use acvm::{ - acir::native_types::{WitnessMap, WitnessStack}, - BlackBoxFunctionSolver, FieldElement, + acir::{ + brillig::ForeignCallResult, + native_types::{WitnessMap, WitnessStack}, + }, + pwg::ForeignCallWaitInfo, + AcirField, BlackBoxFunctionSolver, FieldElement, }; use noirc_abi::Abi; use noirc_driver::{compile_no_check, CompileError, CompileOptions}; use noirc_errors::{debug_info::DebugInfo, FileDiagnostic}; use noirc_frontend::hir::{def_map::TestFunction, Context}; +use noirc_printable_type::ForeignCallError; +use rand::Rng; +use serde::{Deserialize, Serialize}; -use crate::{errors::try_to_diagnose_runtime_error, NargoError}; +use crate::{ + errors::try_to_diagnose_runtime_error, + foreign_calls::{ + mocker::MockForeignCallExecutor, print::PrintForeignCallExecutor, + rpc::RPCForeignCallExecutor, ForeignCall, ForeignCallExecutor, + }, + NargoError, +}; -use super::{execute_program, DefaultForeignCallExecutor}; +use super::execute_program; pub enum TestStatus { Pass, Fail { message: String, error_diagnostic: Option }, + Skipped, CompileError(FileDiagnostic), } impl TestStatus { pub fn failed(&self) -> bool { - !matches!(self, TestStatus::Pass) + !matches!(self, TestStatus::Pass | TestStatus::Skipped) } } @@ -48,23 +63,42 @@ pub fn run_test>( if test_function_has_no_arguments { // Run the backend to ensure the PWG evaluates functions like std::hash::pedersen, // otherwise constraints involving these expressions will not error. + let mut foreign_call_executor = TestForeignCallExecutor::new( + show_output, + foreign_call_resolver_url, + root_path, + package_name, + ); + let circuit_execution = execute_program( &compiled_program.program, WitnessMap::new(), blackbox_solver, - &mut DefaultForeignCallExecutor::new( - show_output, - foreign_call_resolver_url, - root_path, - package_name, - ), + &mut foreign_call_executor, ); - test_status_program_compile_pass( + + let status = test_status_program_compile_pass( test_function, compiled_program.abi, compiled_program.debug, circuit_execution, - ) + ); + + let ignore_foreign_call_failures = + std::env::var("NARGO_IGNORE_TEST_FAILURES_FROM_FOREIGN_CALLS") + .is_ok_and(|var| &var == "true"); + + if let TestStatus::Fail { .. 
} = status { + if ignore_foreign_call_failures + && foreign_call_executor.encountered_unknown_foreign_call + { + TestStatus::Skipped + } else { + status + } + } else { + status + } } else { #[cfg(target_arch = "wasm32")] { @@ -90,7 +124,7 @@ pub fn run_test>( program, initial_witness, blackbox_solver, - &mut DefaultForeignCallExecutor::::new( + &mut TestForeignCallExecutor::::new( false, foreign_call_resolver_url, root_path.clone(), @@ -215,3 +249,93 @@ fn check_expected_failure_message( error_diagnostic, } } + +/// A specialized foreign call executor which tracks whether it has encountered any unknown foreign calls +struct TestForeignCallExecutor { + /// The executor for any [`ForeignCall::Print`] calls. + printer: Option, + mocker: MockForeignCallExecutor, + external: Option, + + encountered_unknown_foreign_call: bool, +} + +impl TestForeignCallExecutor { + fn new( + show_output: bool, + resolver_url: Option<&str>, + root_path: Option, + package_name: Option, + ) -> Self { + let id = rand::thread_rng().gen(); + let printer = if show_output { Some(PrintForeignCallExecutor) } else { None }; + let external_resolver = resolver_url.map(|resolver_url| { + RPCForeignCallExecutor::new(resolver_url, id, root_path, package_name) + }); + TestForeignCallExecutor { + printer, + mocker: MockForeignCallExecutor::default(), + external: external_resolver, + encountered_unknown_foreign_call: false, + } + } +} + +impl Deserialize<'a>> ForeignCallExecutor + for TestForeignCallExecutor +{ + fn execute( + &mut self, + foreign_call: &ForeignCallWaitInfo, + ) -> Result, ForeignCallError> { + // If the circuit has reached a new foreign call opcode then it can't have failed from any previous unknown foreign calls. + self.encountered_unknown_foreign_call = false; + + let foreign_call_name = foreign_call.function.as_str(); + match ForeignCall::lookup(foreign_call_name) { + Some(ForeignCall::Print) => { + if let Some(printer) = &mut self.printer { + printer.execute(foreign_call) + } else { + Ok(ForeignCallResult::default()) + } + } + + Some( + ForeignCall::CreateMock + | ForeignCall::SetMockParams + | ForeignCall::GetMockLastParams + | ForeignCall::SetMockReturns + | ForeignCall::SetMockTimes + | ForeignCall::ClearMock, + ) => self.mocker.execute(foreign_call), + + None => { + // First check if there's any defined mock responses for this foreign call. + match self.mocker.execute(foreign_call) { + Err(ForeignCallError::NoHandler(_)) => (), + response_or_error => return response_or_error, + }; + + if let Some(external_resolver) = &mut self.external { + // If the user has registered an external resolver then we forward any remaining oracle calls there. + match external_resolver.execute(foreign_call) { + Err(ForeignCallError::NoHandler(_)) => (), + response_or_error => return response_or_error, + }; + } + + self.encountered_unknown_foreign_call = true; + + // If all executors have no handler for the given foreign call then we cannot + // return a correct response to the ACVM. The best we can do is to return an empty response, + // this allows us to ignore any foreign calls which exist solely to pass information from inside + // the circuit to the environment (e.g. custom logging) as the execution will still be able to progress. + // + // We optimistically return an empty response for all oracle calls as the ACVM will error + // should a response have been required. 
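+                // (Illustrative) `run_test` above reads `encountered_unknown_foreign_call` after a
+                // failure and downgrades it to `TestStatus::Skipped` when the environment opts in:
+                //   NARGO_IGNORE_TEST_FAILURES_FROM_FOREIGN_CALLS=true nargo test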
+ Ok(ForeignCallResult::default()) + } + } + } +} diff --git a/tooling/nargo/src/ops/transform.rs b/tooling/nargo/src/ops/transform.rs index 9255ac3e0ec..fdda368d150 100644 --- a/tooling/nargo/src/ops/transform.rs +++ b/tooling/nargo/src/ops/transform.rs @@ -6,6 +6,7 @@ use iter_extended::vecmap; use noirc_driver::{CompiledContract, CompiledProgram}; use noirc_errors::debug_info::DebugInfo; +/// Apply ACVM optimizations on the circuit. pub fn transform_program( mut compiled_program: CompiledProgram, expression_width: ExpressionWidth, @@ -18,6 +19,7 @@ pub fn transform_program( compiled_program } +/// Apply the optimizing transformation on each function in the contract. pub fn transform_contract( contract: CompiledContract, expression_width: ExpressionWidth, @@ -25,7 +27,6 @@ pub fn transform_contract( let functions = vecmap(contract.functions, |mut func| { func.bytecode = transform_program_internal(func.bytecode, &mut func.debug, expression_width); - func }); diff --git a/tooling/nargo_cli/Cargo.toml b/tooling/nargo_cli/Cargo.toml index 02e669f5c68..5603b7f4fca 100644 --- a/tooling/nargo_cli/Cargo.toml +++ b/tooling/nargo_cli/Cargo.toml @@ -25,6 +25,7 @@ toml.workspace = true [dependencies] clap.workspace = true fm.workspace = true +fxhash.workspace = true iter-extended.workspace = true nargo.workspace = true nargo_fmt.workspace = true diff --git a/tooling/nargo_cli/benches/criterion.rs b/tooling/nargo_cli/benches/criterion.rs index 488cbfcd243..51de97df139 100644 --- a/tooling/nargo_cli/benches/criterion.rs +++ b/tooling/nargo_cli/benches/criterion.rs @@ -115,7 +115,7 @@ fn criterion_test_execution(c: &mut Criterion, test_program_dir: &Path, force_br let artifacts = RefCell::new(None); let mut foreign_call_executor = - nargo::ops::DefaultForeignCallExecutor::new(false, None, None, None); + nargo::foreign_calls::DefaultForeignCallExecutor::new(false, None, None, None); c.bench_function(&benchmark_name, |b| { b.iter_batched( diff --git a/tooling/nargo_cli/build.rs b/tooling/nargo_cli/build.rs index 04eb0ea0f52..f0334eaf713 100644 --- a/tooling/nargo_cli/build.rs +++ b/tooling/nargo_cli/build.rs @@ -86,7 +86,14 @@ fn read_test_cases( let test_case_dirs = fs::read_dir(test_data_dir).unwrap().flatten().filter(|c| c.path().is_dir()); - test_case_dirs.into_iter().map(|dir| { + test_case_dirs.into_iter().filter_map(|dir| { + // When switching git branches we might end up with non-empty directories that have a `target` + // directory inside them but no `Nargo.toml`. + // These "tests" would always fail, but it's okay to ignore them so we do that here. + if !dir.path().join("Nargo.toml").exists() { + return None; + } + let test_name = dir.file_name().into_string().expect("Directory can't be converted to string"); if test_name.contains('-') { @@ -94,7 +101,7 @@ fn read_test_cases( "Invalid test directory: {test_name}. Cannot include `-`, please convert to `_`" ); } - (test_name, dir.path()) + Some((test_name, dir.path())) }) } @@ -206,8 +213,13 @@ fn test_{test_name}(force_brillig: ForceBrillig, inliner_aggressiveness: Inliner nargo.arg("--program-dir").arg(test_program_dir); nargo.arg("{test_command}").arg("--force"); nargo.arg("--inliner-aggressiveness").arg(inliner_aggressiveness.0.to_string()); + if force_brillig.0 {{ nargo.arg("--force-brillig"); + + // Set the maximum increase so that part of the optimization is exercised (it might fail). 
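+            // (Illustrative) the generated test then shells out the equivalent of:
+            //   nargo <test_command> --force --inliner-aggressiveness <n> --force-brillig --max-bytecode-increase-percent 50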
+ nargo.arg("--max-bytecode-increase-percent"); + nargo.arg("50"); }} {test_content} diff --git a/tooling/nargo_cli/src/cli/compile_cmd.rs b/tooling/nargo_cli/src/cli/compile_cmd.rs index 304988ed516..ff6009981c7 100644 --- a/tooling/nargo_cli/src/cli/compile_cmd.rs +++ b/tooling/nargo_cli/src/cli/compile_cmd.rs @@ -65,6 +65,7 @@ pub(crate) fn run(args: CompileCommand, config: NargoConfig) -> Result<(), CliEr Ok(()) } +/// Continuously recompile the workspace on any Noir file change event. fn watch_workspace(workspace: &Workspace, compile_options: &CompileOptions) -> notify::Result<()> { let (tx, rx) = std::sync::mpsc::channel(); @@ -108,6 +109,8 @@ fn watch_workspace(workspace: &Workspace, compile_options: &CompileOptions) -> n Ok(()) } +/// Parse and compile the entire workspace, then report errors. +/// This is the main entry point used by all other commands that need compilation. pub(super) fn compile_workspace_full( workspace: &Workspace, compile_options: &CompileOptions, @@ -129,6 +132,8 @@ pub(super) fn compile_workspace_full( Ok(()) } +/// Compile binary and contract packages. +/// Returns the merged warnings or errors. fn compile_workspace( file_manager: &FileManager, parsed_files: &ParsedFiles, @@ -144,6 +149,7 @@ fn compile_workspace( // Compile all of the packages in parallel. let program_warnings_or_errors: CompilationResult<()> = compile_programs(file_manager, parsed_files, workspace, &binary_packages, compile_options); + let contract_warnings_or_errors: CompilationResult<()> = compiled_contracts( file_manager, parsed_files, @@ -164,6 +170,7 @@ fn compile_workspace( } } +/// Compile the given binary packages in the workspace. fn compile_programs( file_manager: &FileManager, parsed_files: &ParsedFiles, @@ -171,6 +178,8 @@ fn compile_programs( binary_packages: &[Package], compile_options: &CompileOptions, ) -> CompilationResult<()> { + // Load any existing artifact for a given package, _iff_ it was compiled with the same nargo version. + // The loaded circuit includes backend specific transformations, which might be different from the current target. let load_cached_program = |package| { let program_artifact_path = workspace.package_build_path(package); read_program_from_file(program_artifact_path) @@ -180,19 +189,45 @@ fn compile_programs( }; let compile_package = |package| { + let cached_program = load_cached_program(package); + + // Hash over the entire compiled program, including any post-compile transformations. + // This is used to detect whether `cached_program` is returned by `compile_program`. + let cached_hash = cached_program.as_ref().map(fxhash::hash64); + + // Compile the program, or use the cached artifacts if it matches. let (program, warnings) = compile_program( file_manager, parsed_files, workspace, package, compile_options, - load_cached_program(package), + cached_program, )?; + // Choose the target width for the final, backend specific transformation. let target_width = get_target_width(package.expression_width, compile_options.expression_width); + + // If the compiled program is the same as the cached one, we don't apply transformations again, unless the target width has changed. + // The transformations might not be idempotent, which would risk creating witnesses that don't work with earlier versions, + // based on which we might have generated a verifier already. 
+ if cached_hash == Some(fxhash::hash64(&program)) { + let width_matches = program + .program + .functions + .iter() + .all(|circuit| circuit.expression_width == target_width); + + if width_matches { + return Ok(((), warnings)); + } + } + // Run ACVM optimizations and set the target width. let program = nargo::ops::transform_program(program, target_width); + // Check solvability. nargo::ops::check_program(&program)?; + // Overwrite the build artifacts with the final circuit, which includes the backend specific transformations. save_program_to_file(&program.into(), &package.name, workspace.target_directory_path()); Ok(((), warnings)) @@ -208,6 +243,7 @@ fn compile_programs( collect_errors(program_results).map(|(_, warnings)| ((), warnings)) } +/// Compile the given contracts in the workspace. fn compiled_contracts( file_manager: &FileManager, parsed_files: &ParsedFiles, diff --git a/tooling/nargo_cli/src/cli/execute_cmd.rs b/tooling/nargo_cli/src/cli/execute_cmd.rs index 8dc71b1c7e5..fa95d3123c6 100644 --- a/tooling/nargo_cli/src/cli/execute_cmd.rs +++ b/tooling/nargo_cli/src/cli/execute_cmd.rs @@ -7,7 +7,7 @@ use clap::Args; use nargo::constants::PROVER_INPUT_FILE; use nargo::errors::try_to_diagnose_runtime_error; -use nargo::ops::DefaultForeignCallExecutor; +use nargo::foreign_calls::DefaultForeignCallExecutor; use nargo::package::{CrateName, Package}; use nargo_toml::{get_package_manifest, resolve_workspace_from_toml, PackageSelection}; use noirc_abi::input_parser::{Format, InputValue}; diff --git a/tooling/nargo_cli/src/cli/info_cmd.rs b/tooling/nargo_cli/src/cli/info_cmd.rs index cf416b1fa5f..769a1f79d81 100644 --- a/tooling/nargo_cli/src/cli/info_cmd.rs +++ b/tooling/nargo_cli/src/cli/info_cmd.rs @@ -4,7 +4,7 @@ use clap::Args; use iter_extended::vecmap; use nargo::{ constants::PROVER_INPUT_FILE, - ops::DefaultForeignCallExecutor, + foreign_calls::DefaultForeignCallExecutor, package::{CrateName, Package}, }; use nargo_toml::{get_package_manifest, resolve_workspace_from_toml, PackageSelection}; diff --git a/tooling/nargo_cli/src/cli/test_cmd.rs b/tooling/nargo_cli/src/cli/test_cmd.rs index 7b0201226ef..aa0ee1bb94b 100644 --- a/tooling/nargo_cli/src/cli/test_cmd.rs +++ b/tooling/nargo_cli/src/cli/test_cmd.rs @@ -255,6 +255,12 @@ fn display_test_report( ); } } + TestStatus::Skipped { .. 
} => { + writer + .set_color(ColorSpec::new().set_fg(Some(Color::Yellow))) + .expect("Failed to set color"); + writeln!(writer, "skipped").expect("Failed to write to stderr"); + } TestStatus::CompileError(err) => { noirc_errors::reporter::report_all( file_manager.as_file_map(), diff --git a/tooling/nargo_cli/tests/stdlib-props.rs b/tooling/nargo_cli/tests/stdlib-props.rs index 0013a90b4ff..86c225831b9 100644 --- a/tooling/nargo_cli/tests/stdlib-props.rs +++ b/tooling/nargo_cli/tests/stdlib-props.rs @@ -2,10 +2,7 @@ use std::{cell::RefCell, collections::BTreeMap, path::Path}; use acvm::{acir::native_types::WitnessStack, AcirField, FieldElement}; use iter_extended::vecmap; -use nargo::{ - ops::{execute_program, DefaultForeignCallExecutor}, - parse_all, -}; +use nargo::{foreign_calls::DefaultForeignCallExecutor, ops::execute_program, parse_all}; use noirc_abi::input_parser::InputValue; use noirc_driver::{ compile_main, file_manager_with_stdlib, prepare_crate, CompilationResult, CompileOptions, @@ -64,6 +61,7 @@ fn prepare_and_compile_snippet( ) -> CompilationResult { let (mut context, root_crate_id) = prepare_snippet(source); let options = CompileOptions { force_brillig, ..Default::default() }; + // TODO: Run nargo::ops::transform_program? compile_main(&mut context, root_crate_id, &options, None) } diff --git a/tooling/nargo_cli/tests/stdlib-tests.rs b/tooling/nargo_cli/tests/stdlib-tests.rs index bdc92e625ab..99f0c9a2e7f 100644 --- a/tooling/nargo_cli/tests/stdlib-tests.rs +++ b/tooling/nargo_cli/tests/stdlib-tests.rs @@ -138,6 +138,12 @@ fn display_test_report( ); } } + TestStatus::Skipped { .. } => { + writer + .set_color(ColorSpec::new().set_fg(Some(Color::Yellow))) + .expect("Failed to set color"); + writeln!(writer, "skipped").expect("Failed to write to stderr"); + } TestStatus::CompileError(err) => { noirc_errors::reporter::report_all( file_manager.as_file_map(), diff --git a/tooling/noirc_abi/Cargo.toml b/tooling/noirc_abi/Cargo.toml index a7baf334bff..22114408e18 100644 --- a/tooling/noirc_abi/Cargo.toml +++ b/tooling/noirc_abi/Cargo.toml @@ -23,8 +23,8 @@ num-bigint = "0.4" num-traits = "0.2" [dev-dependencies] -strum = "0.24" -strum_macros = "0.24" +strum.workspace = true +strum_macros.workspace = true proptest.workspace = true proptest-derive.workspace = true diff --git a/tooling/noirc_abi/proptest-regressions/input_parser/json.txt b/tooling/noirc_abi/proptest-regressions/input_parser/json.txt new file mode 100644 index 00000000000..19de8eeaf48 --- /dev/null +++ b/tooling/noirc_abi/proptest-regressions/input_parser/json.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc b3f9ae88d54944ca274764f4d99a2023d4b0ac09beb89bc599cbba1e45dd3620 # shrinks to (typ, value) = (Integer { sign: Signed, width: 1 }, -1) diff --git a/tooling/noirc_abi/proptest-regressions/input_parser/toml.txt b/tooling/noirc_abi/proptest-regressions/input_parser/toml.txt new file mode 100644 index 00000000000..1448cb67ef1 --- /dev/null +++ b/tooling/noirc_abi/proptest-regressions/input_parser/toml.txt @@ -0,0 +1,9 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. 
+# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 9d200afb8f5c01e3414d24eebe1436a7eef5377a46a9a9235aaa7f81e0b33656 # shrinks to (typ, value) = (Integer { sign: Signed, width: 8 }, -1) +cc 7fd29637e5566d819992185c1a95438e9949a555928a911b3918eed2e3f7a1fd # shrinks to (typ, value) = (Integer { sign: Signed, width: 64 }, -1) +cc 8ecbda39d887674b53ca23a861ac30fbb10c123bb70c57e69b336c86a3d9dea8 # shrinks to (abi, input_map) = (Abi { parameters: [AbiParameter { name: "¡", typ: Struct { path: "�)\u{1b}=�?Ⱥ\u{59424}?{\u{e4d5e}%Ѩ/Q\u{36a17}/*\";\u{b}&iC_\u{d313f}S\u{1b}\u{9dfec}\r/\u{10530d}", fields: [("?p*\"/\u{202e}\u{6f038}\u{537ca}.y@~𘛶?4\u{1b}*", Field), (".Ⱥ/$\u{7f}\u{103c06}%\\\u{202e}][0\u{88479}]\"*~\u{36fd5}\u{5}\u{feff}]{/", Tuple { fields: [String { length: 937 }] }), ("r\u{ac3a5}&:", Boolean), ("$d6🕴/:|�\u{37f8b}\r\u{a13b7}C$𲁹\\&\u{f8712}?\u{db61c}t%\u{57be1}\0", Field), ("/\u{6378b}\u{a426c}¥\u{7}/\u{fcb29}$\u{53c6b}\u{12d6f}\u{12bd3}.\u{f2f82}\u{8613e}*$\u{fd32f}\u{e29f7}\0𨺉'¬\"1", Struct { path: "\\\u{4a5ac}<\u{9e505}\u{4f3af}🕴&?<:^\u{7}\u{88}\u{3e1ff}(¥\u{531f3}K{:¥𦺀", fields: [("n\0Ѩ/\u{1b}𥐰\u{a4906}�¥`{\u{389d4}`1\u{7708a})\u{3dac4}8\u{93e5f}㒭\\\"\u{e6824}\u{b}Ѩ\u{88946}Ⱥ{", Integer { sign: Signed, width: 127 })] }), ("¥🕴\u{1b}¥🕴=sR\0\u{35f36}\u{867dc}>ä\u{202e}f:BȺ?:``*·¥\u{74ca5}\"", Tuple { fields: [Boolean, Field, String { length: 205 }, String { length: 575 }, Integer { sign: Signed, width: 124 }, String { length: 923 }, String { length: 294 }] })] }, visibility: Public }], return_type: None, error_types: {} }, {"¡": Struct({"$d6🕴/:|�\u{37f8b}\r\u{a13b7}C$𲁹\\&\u{f8712}?\u{db61c}t%\u{57be1}\0": Field(-8275115097504119425402713293372777967031130481426075481525511323101167533940), ".Ⱥ/$\u{7f}\u{103c06}%\\\u{202e}][0\u{88479}]\"*~\u{36fd5}\u{5}\u{feff}]{/": Vec([String("A \0A 0 aA0 a0aa00 A\000 0 \0\0aA\0\0a \0 \0a 0A\0A\0 Aa0aAA0A\0aa\00 0\0\0\0\0\00a Aa0 \0 a A0 \0AA0A Aa Aa\00aAaAaaA0A0 aA0 \0 Aa\00 \0000AAA a \0AAaaA\0\0a A0a0AA\0aA00 aA a0A\0AAa0a\0A0a\0\0A0A \00Aaaaa a A AO.*D\r.`bD4a\n*\u{15}\\B\"ace.8&A\t[AV8w<\u{18}\"\u{f}4`^Q\u{1b}U*$Z/\0\u{b}]qw${`\"=X&A\\\u{e}%`\\:\"$\u{1}.(6_C:\u{7}a`V=N**\u{1b})#Y\u{7f}#\u{b}$l\t}.Mns5!\t*$g\u{18}\rC\u{11}\"$=\u{7}.?&\u{1}yW\t.Y|<6\u{12}\u{e}/4JJ*&/V$`\"&`x#R\np\\%'*\n:P\0K\u{b}*`\r7Ym\t_\u{b}=$\u{16}`0v\u{7f}'NV^N4J<9=G*A:!b\u{1c}:'c{ST&z![\u{7f}/.={E*pmaWC\u{7f}7p{<\"']\u{8}?`\u{1b}\"\\\u{1}$\u{18}/!\u{16}-\t:E7CUs%_qw*xf.S\t\u{4}'=\"&%t'\u{1f}\u{7f}\u{b}$.=f\u{6}\"$A}xV_$\u{1a}nH\n\u{1b}?<&\n\u{15}U\\-b\u{1d}|\u{b}\u{2}t \rwA{L\u{11}\u{6}\u{10}\0\u{1b}G[x?&Yi?&7\u{b}?\r\u{1f}b\\$=\u{b}x& Q/\t\u{4}|X\"7\"{\0\0j'.\0\\e1zR.\u{c}\n<\u{b}Q*R+y8\u{19}(o\u{1f}@m\nt+\u{7f}Q\\+.Rn?\u{17}UZ\"$\u{b}/\0B=9=\t{\u{8}qZ&`!:D{\u{6}IO.H\u{7f}:?/3@\r\u{1b}oä\u{202e}f:BȺ?:``*·¥\u{74ca5}\"": Vec([Field(1), Field(8822392870083219098626030699076694602179106416928939583840848325203494062169), String("*TXn;{}\"_)_9\nk\\#ts\u{10}%\\c\n/2._::Oj*\u{7f}\0\r&PUMl\u{10}$/u?L}\u{7f}*P&<%=\u{7}S#%A\n \u{e}\\#v!\"\nepRp.{vH{&@\t\u{1f}\u{b}?=T\u{f}\"B\u{11}\n/{HY.\u{16}\n\nj<&\u{3}{f\n/9J*&x.$/,\r\0\u{1c}'\u{5}\u{13}\u{1b}`T\0`\n&/&\u{15}\u{b}w:{SK\u{7f}\\apR%/'0`0\n'd$$\u{7f}Vs\t<{\nDTT\\F\n\u{15}y.\\\t*-)&D$*u\u{b}\u{1b}?{\u{b}/\n\u{7f}0*.7\0\n:\u{b}.rSk<6~>{#"), 
String(".\"JA%q6i\ra/:F\u{16}?q<\t\rN\\13?H<;?{`\u{1d}p{.\"5?*@'N\"\u{1a}P,\u{1b}\u{7f}c+dt5':Y\u{1b}k/G>k/eM$XIX')\u{1b}'&\u{7f}\\\r\u{1b}`'P_.\n.?\0p`Y\u{c}`._\u{b}B\0\ng/*v$jfJ:\u{c}\u{1b}Pv}xn7ph@#{_<{.JD?r%'E\n7s9n/],u![;%*\u{2}{y`MgRdok8\"%<*>*{GyFJ}?\0W%#\0\u{1b}\u{7f}\u{16}G:\t=w\u{7f}:q\u{7f}:{k?\u{b}(:ca{$*1X/cw\u{1b}Z6I\rX\0\u{1b}(.^14\r\\=s\u{1b}w\u{3}F~\n\u{1e})/$0:=[\u{1},\\\\\tg\u{16}:],J`\0N\n\u{1b}\u{1b}\u{1b}{.xb\u{1a}\r'12#?e\\#/\tA\u{7f}\".\\Ke=\\?!v+P\u{17}\r\u{12}x.=A.`0<&?\niR/*WW\rnV)5vY.~\n _h\0&5f#\r\u{2}-S%\t s..\u{7f}!X}\"=\"?\u{5}y\u{4}`fr&R&d: 1Ht\"4`y_/S.71#{|%$%&ehy\u{16}J_\u{e}=:.%'\"N=J:\r:{&.\u{12}\u{b})&N\u{10}R_3;11\u{b}Qd<`<{?xF:~\"%<=<<\03:t??&\r;{\u{13}?__Y\u{6})\\k,vs?\n`G(*\n!\u{1b}[@z\0$?*yKLJh_\u{13}FkY'\\?T^\u{1f}$1n`'[\n\u{7f}\0+l\u{b}\u{1a}E\u{b}&(/\u{b}\rr\t:&\0+N'N:oC:*``IN\u{b}*.:\t$7+'*U:\t Result { let json_value = match (value, abi_type) { + (InputValue::Field(f), AbiType::Integer { sign: crate::Sign::Signed, width }) => { + JsonTypes::String(field_to_signed_hex(*f, *width)) + } (InputValue::Field(f), AbiType::Field | AbiType::Integer { .. }) => { JsonTypes::String(Self::format_field_string(*f)) } @@ -143,6 +146,9 @@ impl InputValue { ) -> Result { let input_value = match (value, param_type) { (JsonTypes::String(string), AbiType::String { .. }) => InputValue::String(string), + (JsonTypes::String(string), AbiType::Integer { sign: crate::Sign::Signed, width }) => { + InputValue::Field(parse_str_to_signed(&string, *width)?) + } ( JsonTypes::String(string), AbiType::Field | AbiType::Integer { .. } | AbiType::Boolean, @@ -192,3 +198,40 @@ impl InputValue { Ok(input_value) } } + +#[cfg(test)] +mod test { + use proptest::prelude::*; + + use crate::{ + arbitrary::arb_abi_and_input_map, + input_parser::{arbitrary::arb_signed_integer_type_and_value, json::JsonTypes, InputValue}, + }; + + use super::{parse_json, serialize_to_json}; + + proptest! 
{ + #[test] + fn serializing_and_parsing_returns_original_input((abi, input_map) in arb_abi_and_input_map()) { + let json = serialize_to_json(&input_map, &abi).expect("should be serializable"); + let parsed_input_map = parse_json(&json, &abi).expect("should be parsable"); + + prop_assert_eq!(parsed_input_map, input_map); + } + + #[test] + fn signed_integer_serialization_roundtrip((typ, value) in arb_signed_integer_type_and_value()) { + let string_input = JsonTypes::String(value.to_string()); + let input_value = InputValue::try_from_json(string_input, &typ, "foo").expect("should be parsable"); + let JsonTypes::String(output_string) = JsonTypes::try_from_input_value(&input_value, &typ).expect("should be serializable") else { + panic!("wrong type output"); + }; + let output_number = if let Some(output_string) = output_string.strip_prefix("-0x") { + -i64::from_str_radix(output_string, 16).unwrap() + } else { + i64::from_str_radix(output_string.strip_prefix("0x").unwrap(), 16).unwrap() + }; + prop_assert_eq!(output_number, value); + } + } +} diff --git a/tooling/noirc_abi/src/input_parser/mod.rs b/tooling/noirc_abi/src/input_parser/mod.rs index d7bbb0adfe3..b7732235eb2 100644 --- a/tooling/noirc_abi/src/input_parser/mod.rs +++ b/tooling/noirc_abi/src/input_parser/mod.rs @@ -248,6 +248,11 @@ mod serialization_tests { typ: AbiType::Field, visibility: AbiVisibility::Private, }, + AbiParameter { + name: "signed_example".into(), + typ: AbiType::Integer { sign: Sign::Signed, width: 8 }, + visibility: AbiVisibility::Private, + }, AbiParameter { name: "bar".into(), typ: AbiType::Struct { @@ -272,6 +277,7 @@ mod serialization_tests { let input_map: BTreeMap = BTreeMap::from([ ("foo".into(), InputValue::Field(FieldElement::one())), + ("signed_example".into(), InputValue::Field(FieldElement::from(240u128))), ( "bar".into(), InputValue::Struct(BTreeMap::from([ @@ -317,7 +323,9 @@ fn parse_str_to_field(value: &str) -> Result { } fn parse_str_to_signed(value: &str, width: u32) -> Result { - let big_num = if let Some(hex) = value.strip_prefix("0x") { + let big_num = if let Some(hex) = value.strip_prefix("-0x") { + BigInt::from_str_radix(hex, 16).map(|value| -value) + } else if let Some(hex) = value.strip_prefix("0x") { BigInt::from_str_radix(hex, 16) } else { BigInt::from_str_radix(value, 10) @@ -357,12 +365,23 @@ fn field_from_big_int(bigint: BigInt) -> FieldElement { } } +fn field_to_signed_hex(f: FieldElement, bit_size: u32) -> String { + let f_u128 = f.to_u128(); + let max = 2_u128.pow(bit_size - 1) - 1; + if f_u128 > max { + let f = FieldElement::from(2_u128.pow(bit_size) - f_u128); + format!("-0x{}", f.to_hex()) + } else { + format!("0x{}", f.to_hex()) + } +} + #[cfg(test)] mod test { use acvm::{AcirField, FieldElement}; use num_bigint::BigUint; - use super::parse_str_to_field; + use super::{parse_str_to_field, parse_str_to_signed}; fn big_uint_from_field(field: FieldElement) -> BigUint { BigUint::from_bytes_be(&field.to_be_bytes()) @@ -400,4 +419,38 @@ mod test { let noncanonical_field = FieldElement::modulus().to_string(); assert!(parse_str_to_field(&noncanonical_field).is_err()); } + + #[test] + fn test_parse_str_to_signed() { + let value = parse_str_to_signed("1", 8).unwrap(); + assert_eq!(value, FieldElement::from(1_u128)); + + let value = parse_str_to_signed("-1", 8).unwrap(); + assert_eq!(value, FieldElement::from(255_u128)); + + let value = parse_str_to_signed("-1", 16).unwrap(); + assert_eq!(value, FieldElement::from(65535_u128)); + } +} + +#[cfg(test)] +mod arbitrary { + use 
diff --git a/tooling/noirc_abi/src/input_parser/toml.rs b/tooling/noirc_abi/src/input_parser/toml.rs
index 321d3511b5d..6f2be68a0c4 100644
--- a/tooling/noirc_abi/src/input_parser/toml.rs
+++ b/tooling/noirc_abi/src/input_parser/toml.rs
@@ -1,4 +1,4 @@
-use super::{parse_str_to_field, parse_str_to_signed, InputValue};
+use super::{field_to_signed_hex, parse_str_to_field, parse_str_to_signed, InputValue};
 use crate::{errors::InputParserError, Abi, AbiType, MAIN_RETURN_NAME};
 use acvm::{AcirField, FieldElement};
 use iter_extended::{try_btree_map, try_vecmap};
@@ -60,7 +60,7 @@ pub(crate) fn serialize_to_toml(
     Ok(toml_string)
 }
 
-#[derive(Debug, Deserialize, Serialize, Clone)]
+#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
 #[serde(untagged)]
 enum TomlTypes {
     // This is most likely going to be a hex string
@@ -83,6 +83,9 @@ impl TomlTypes {
         abi_type: &AbiType,
     ) -> Result<TomlTypes, InputParserError> {
         let toml_value = match (value, abi_type) {
+            (InputValue::Field(f), AbiType::Integer { sign: crate::Sign::Signed, width }) => {
+                TomlTypes::String(field_to_signed_hex(*f, *width))
+            }
             (InputValue::Field(f), AbiType::Field | AbiType::Integer { .. }) => {
                 let f_str = format!("0x{}", f.to_hex());
                 TomlTypes::String(f_str)
@@ -126,6 +129,7 @@ impl InputValue {
     ) -> Result<InputValue, InputParserError> {
         let input_value = match (value, param_type) {
             (TomlTypes::String(string), AbiType::String { .. }) => InputValue::String(string),
+
             (
                 TomlTypes::String(string),
                 AbiType::Field
@@ -139,7 +143,7 @@ impl InputValue {
                 TomlTypes::Integer(integer),
                 AbiType::Field | AbiType::Integer { .. } | AbiType::Boolean,
             ) => {
-                let new_value = FieldElement::from(i128::from(integer));
+                let new_value = FieldElement::from(u128::from(integer));
                 InputValue::Field(new_value)
             }
@@ -179,3 +183,40 @@ impl InputValue {
         Ok(input_value)
     }
 }
+
+#[cfg(test)]
+mod test {
+    use proptest::prelude::*;
+
+    use crate::{
+        arbitrary::arb_abi_and_input_map,
+        input_parser::{arbitrary::arb_signed_integer_type_and_value, toml::TomlTypes, InputValue},
+    };
+
+    use super::{parse_toml, serialize_to_toml};
+
+    proptest! {
+        #[test]
+        fn serializing_and_parsing_returns_original_input((abi, input_map) in arb_abi_and_input_map()) {
+            let toml = serialize_to_toml(&input_map, &abi).expect("should be serializable");
+            let parsed_input_map = parse_toml(&toml, &abi).expect("should be parsable");
+
+            prop_assert_eq!(parsed_input_map, input_map);
+        }
+
+        #[test]
+        fn signed_integer_serialization_roundtrip((typ, value) in arb_signed_integer_type_and_value()) {
+            let string_input = TomlTypes::String(value.to_string());
+            let input_value = InputValue::try_from_toml(string_input.clone(), &typ, "foo").expect("should be parsable");
+            let TomlTypes::String(output_string) = TomlTypes::try_from_input_value(&input_value, &typ).expect("should be serializable") else {
+                panic!("wrong type output");
+            };
+            let output_number = if let Some(output_string) = output_string.strip_prefix("-0x") {
+                -i64::from_str_radix(output_string, 16).unwrap()
+            } else {
+                i64::from_str_radix(output_string.strip_prefix("0x").unwrap(), 16).unwrap()
+            };
+            prop_assert_eq!(output_number, value);
+        }
+    }
+}
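Note: the i128::from(integer) to u128::from(integer) change touches only the bare-integer path. TomlTypes::Integer appears to hold an unsigned payload, so widening through u128 keeps the conversion sign-free; negative values are expected to arrive as quoted strings ("-1", "-0x10") and take the string path through parse_str_to_signed instead. A sketch of the distinction, assuming a u64 payload (my assumption, not shown in this hunk):

    fn main() {
        // Bare TOML integer: unsigned, widened without reinterpretation.
        let integer: u64 = 240; // assumed TomlTypes::Integer payload
        assert_eq!(u128::from(integer), 240u128);
        // A signed input like -1 is instead written as the string "-1" (or
        // "-0x1"), which parse_str_to_signed maps to 2^w - 1 in the field.
        assert_eq!(2u128.pow(8) - 1, 255);
    }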
diff --git a/tooling/noirc_abi/src/lib.rs b/tooling/noirc_abi/src/lib.rs
index b1b199727c2..bd5674d64f1 100644
--- a/tooling/noirc_abi/src/lib.rs
+++ b/tooling/noirc_abi/src/lib.rs
@@ -49,6 +49,7 @@ pub const MAIN_RETURN_NAME: &str = "return";
 /// depends on the types of programs that users want to do. I don't envision string manipulation
 /// in programs, however it is possible to support, with many complications like encoding character set
 /// support.
+#[derive(Hash)]
 pub enum AbiType {
     Field,
     Array {
@@ -77,7 +78,7 @@ pub enum AbiType {
     },
 }
 
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
 #[cfg_attr(test, derive(arbitrary::Arbitrary))]
 #[serde(rename_all = "lowercase")]
 /// Represents whether the parameter is public or known only to the prover.
@@ -89,7 +90,7 @@ pub enum AbiVisibility {
     DataBus,
 }
 
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
 #[cfg_attr(test, derive(arbitrary::Arbitrary))]
 #[serde(rename_all = "lowercase")]
 pub enum Sign {
@@ -146,7 +147,7 @@ impl From<&AbiType> for PrintableType {
     }
 }
 
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]
 #[cfg_attr(test, derive(arbitrary::Arbitrary))]
 /// An argument or return value of the circuit's `main` function.
 pub struct AbiParameter {
@@ -163,7 +164,7 @@ impl AbiParameter {
     }
 }
 
-#[derive(Clone, Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Serialize, Deserialize, Hash)]
 #[cfg_attr(test, derive(arbitrary::Arbitrary))]
 pub struct AbiReturnType {
     #[cfg_attr(test, proptest(strategy = "arbitrary::arb_abi_type()"))]
     pub abi_type: AbiType,
     pub visibility: AbiVisibility,
 }
 
@@ -171,7 +172,7 @@
-#[derive(Clone, Debug, Default, Serialize, Deserialize)]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, Hash)]
 #[cfg_attr(test, derive(arbitrary::Arbitrary))]
 pub struct Abi {
     /// An ordered list of the arguments to the program's `main` function, specifying their types and visibility.
@@ -459,7 +460,7 @@ pub enum AbiValue {
     },
 }
 
-#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
 #[serde(tag = "error_kind", rename_all = "lowercase")]
 pub enum AbiErrorType {
     FmtString { length: u32, item_types: Vec<AbiType> },
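Note: the lib.rs hunks derive Hash across AbiType, AbiVisibility, Sign, AbiParameter, AbiReturnType, Abi, and AbiErrorType, but the consumer of those impls is not shown in this patch. A typical use this unlocks is fingerprinting an ABI, e.g. as a cache or dedup key (illustrative sketch only, not an API from the diff):

    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    // Works for any type that now derives Hash, e.g. noirc_abi::Abi.
    fn fingerprint<T: Hash>(value: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }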
diff --git a/tooling/noirc_abi_wasm/Cargo.toml b/tooling/noirc_abi_wasm/Cargo.toml
index daa619ca01d..b00d580515e 100644
--- a/tooling/noirc_abi_wasm/Cargo.toml
+++ b/tooling/noirc_abi_wasm/Cargo.toml
@@ -1,9 +1,11 @@
 [package]
 name = "noirc_abi_wasm"
+description = "An ABI encoder for the Noir language"
 version.workspace = true
 authors.workspace = true
 edition.workspace = true
 license.workspace = true
+repository.workspace = true
 
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
diff --git a/tooling/noirc_abi_wasm/test/browser/abi_encode.test.ts b/tooling/noirc_abi_wasm/test/browser/abi_encode.test.ts
index e1aaf0dc2c0..ac18495919c 100644
--- a/tooling/noirc_abi_wasm/test/browser/abi_encode.test.ts
+++ b/tooling/noirc_abi_wasm/test/browser/abi_encode.test.ts
@@ -15,7 +15,8 @@ it('recovers original inputs when abi encoding and decoding', async () => {
   const foo: Field = inputs.foo as Field;
   const bar: Field[] = inputs.bar as Field[];
   expect(BigInt(decoded_inputs.inputs.foo)).to.be.equal(BigInt(foo));
-  expect(BigInt(decoded_inputs.inputs.bar[0])).to.be.equal(BigInt(bar[0]));
-  expect(BigInt(decoded_inputs.inputs.bar[1])).to.be.equal(BigInt(bar[1]));
+  expect(parseInt(decoded_inputs.inputs.bar[0])).to.be.equal(parseInt(bar[0].toString()));
+  expect(parseInt(decoded_inputs.inputs.bar[1])).to.be.equal(parseInt(bar[1].toString()));
+  expect(parseInt(decoded_inputs.inputs.bar[2])).to.be.equal(parseInt(bar[2].toString()));
   expect(decoded_inputs.return_value).to.be.null;
 });
diff --git a/tooling/noirc_abi_wasm/test/node/abi_encode.test.ts b/tooling/noirc_abi_wasm/test/node/abi_encode.test.ts
index a49c10b6ea6..e87618d84da 100644
--- a/tooling/noirc_abi_wasm/test/node/abi_encode.test.ts
+++ b/tooling/noirc_abi_wasm/test/node/abi_encode.test.ts
@@ -11,7 +11,8 @@ it('recovers original inputs when abi encoding and decoding', async () => {
   const foo: Field = inputs.foo as Field;
   const bar: Field[] = inputs.bar as Field[];
   expect(BigInt(decoded_inputs.inputs.foo)).to.be.equal(BigInt(foo));
-  expect(BigInt(decoded_inputs.inputs.bar[0])).to.be.equal(BigInt(bar[0]));
-  expect(BigInt(decoded_inputs.inputs.bar[1])).to.be.equal(BigInt(bar[1]));
+  expect(parseInt(decoded_inputs.inputs.bar[0])).to.be.equal(parseInt(bar[0].toString()));
+  expect(parseInt(decoded_inputs.inputs.bar[1])).to.be.equal(parseInt(bar[1].toString()));
+  expect(parseInt(decoded_inputs.inputs.bar[2])).to.be.equal(parseInt(bar[2].toString()));
   expect(decoded_inputs.return_value).to.be.null;
 });
diff --git a/tooling/noirc_abi_wasm/test/shared/abi_encode.ts b/tooling/noirc_abi_wasm/test/shared/abi_encode.ts
index 62eb7658f43..b789bb05371 100644
--- a/tooling/noirc_abi_wasm/test/shared/abi_encode.ts
+++ b/tooling/noirc_abi_wasm/test/shared/abi_encode.ts
@@ -5,7 +5,7 @@ export const abi: Abi = {
     { name: 'foo', type: { kind: 'field' }, visibility: 'private' },
     {
       name: 'bar',
-      type: { kind: 'array', length: 2, type: { kind: 'field' } },
+      type: { kind: 'array', length: 3, type: { kind: 'integer', sign: 'signed', width: 32 } },
       visibility: 'private',
     },
   ],
@@ -15,5 +15,5 @@ export const abi: Abi = {
 
 export const inputs: InputMap = {
   foo: '1',
-  bar: ['1', '2'],
+  bar: ['1', '2', '-1'],
 };
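Note: with bar retyped as a signed width-32 array, the decoder now returns negative hex strings such as "-0x01" for the input '-1'. JavaScript's BigInt() constructor rejects a sign in front of a hex prefix (BigInt("-0x01") throws a SyntaxError), which is presumably why the assertions switched to parseInt, which accepts a sign plus 0x prefix. The underlying encoding, in Rust terms:

    fn main() {
        // "-1" at width 32 is stored as the field value 2^32 - 1 ...
        let encoded = 2u64.pow(32) - 1;
        assert_eq!(encoded, 0xFFFF_FFFF);
        // ... and rendered on the way out as a minus sign plus the
        // magnitude 2^32 - encoded = 1, i.e. "-0x01".
        assert_eq!(2u64.pow(32) - encoded, 1);
    }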
diff --git a/tooling/profiler/src/cli/execution_flamegraph_cmd.rs b/tooling/profiler/src/cli/execution_flamegraph_cmd.rs
index 981d08a3eb1..6d6da89f660 100644
--- a/tooling/profiler/src/cli/execution_flamegraph_cmd.rs
+++ b/tooling/profiler/src/cli/execution_flamegraph_cmd.rs
@@ -8,7 +8,7 @@ use crate::flamegraph::{BrilligExecutionSample, FlamegraphGenerator, InfernoFlam
 use crate::fs::{read_inputs_from_file, read_program_from_file};
 use crate::opcode_formatter::format_brillig_opcode;
 use bn254_blackbox_solver::Bn254BlackBoxSolver;
-use nargo::ops::DefaultForeignCallExecutor;
+use nargo::foreign_calls::DefaultForeignCallExecutor;
 use noirc_abi::input_parser::Format;
 use noirc_artifacts::debug::DebugArtifact;
diff --git a/tooling/profiler/src/cli/gates_flamegraph_cmd.rs b/tooling/profiler/src/cli/gates_flamegraph_cmd.rs
index c3ae29de058..e68a8cd5bd2 100644
--- a/tooling/profiler/src/cli/gates_flamegraph_cmd.rs
+++ b/tooling/profiler/src/cli/gates_flamegraph_cmd.rs
@@ -31,6 +31,10 @@ pub(crate) struct GatesFlamegraphCommand {
     /// The output folder for the flamegraph svg files
     #[clap(long, short)]
     output: String,
+
+    /// The output name for the flamegraph svg files
+    #[clap(long, short = 'f')]
+    output_filename: Option<String>,
 }
 
 pub(crate) fn run(args: GatesFlamegraphCommand) -> eyre::Result<()> {
@@ -43,6 +47,7 @@ pub(crate) fn run(args: GatesFlamegraphCommand) -> eyre::Result<()> {
         },
         &InfernoFlamegraphGenerator { count_name: "gates".to_string() },
         &PathBuf::from(args.output),
+        args.output_filename,
     )
 }
 
@@ -51,6 +56,7 @@ fn run_with_provider(
     gates_provider: &Provider,
     flamegraph_generator: &Generator,
     output_path: &Path,
+    output_filename: Option<String>,
 ) -> eyre::Result<()> {
     let mut program =
         read_program_from_file(artifact_path).context("Error reading program from file")?;
@@ -91,13 +97,18 @@ fn run_with_provider(
             })
             .collect();
 
+        let output_filename = if let Some(output_filename) = &output_filename {
+            format!("{}::{}::gates.svg", output_filename, func_name)
+        } else {
+            format!("{}::gates.svg", func_name)
+        };
         flamegraph_generator.generate_flamegraph(
             samples,
             &debug_artifact.debug_symbols[func_idx],
             &debug_artifact,
             artifact_path.to_str().unwrap(),
             &func_name,
-            &Path::new(&output_path).join(Path::new(&format!("{}_gates.svg", &func_name))),
+            &Path::new(&output_path).join(Path::new(&output_filename)),
         )?;
     }
 
@@ -189,11 +200,17 @@ mod tests {
         };
         let flamegraph_generator = TestFlamegraphGenerator::default();
 
-        super::run_with_provider(&artifact_path, &provider, &flamegraph_generator, temp_dir.path())
-            .expect("should run without errors");
+        super::run_with_provider(
+            &artifact_path,
+            &provider,
+            &flamegraph_generator,
+            temp_dir.path(),
+            Some(String::from("test_filename")),
+        )
+        .expect("should run without errors");
 
         // Check that the output file was written to
-        let output_file = temp_dir.path().join("main_gates.svg");
+        let output_file = temp_dir.path().join("test_filename::main::gates.svg");
         assert!(output_file.exists());
     }
 }
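Note: the gates_flamegraph_cmd.rs hunks make the per-function SVG name configurable. A sketch of the resulting naming scheme (the standalone helper is illustrative; the real code inlines this logic in run_with_provider):

    fn output_name(prefix: Option<&str>, func_name: &str) -> String {
        match prefix {
            // With -f/--output-filename set: "<prefix>::<function>::gates.svg".
            Some(prefix) => format!("{prefix}::{func_name}::gates.svg"),
            // Default, changed from "<function>_gates.svg" to "<function>::gates.svg".
            None => format!("{func_name}::gates.svg"),
        }
    }

    fn main() {
        assert_eq!(output_name(Some("test_filename"), "main"), "test_filename::main::gates.svg");
        assert_eq!(output_name(None, "main"), "main::gates.svg");
    }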