diff --git a/rust-toolchain b/rust-toolchain index 2bf5ad0447..f4ea31cf76 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -stable +nightly-2024-05-01 \ No newline at end of file diff --git a/stark-middleware/src/verifier/error.rs b/stark-middleware/src/verifier/error.rs index c9afebea7b..102be33586 100644 --- a/stark-middleware/src/verifier/error.rs +++ b/stark-middleware/src/verifier/error.rs @@ -1,4 +1,4 @@ -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub enum VerificationError { InvalidProofShape, /// An error occurred while verifying the claimed openings. diff --git a/stark-middleware/tests/fib_selector_air/air.rs b/stark-middleware/tests/fib_selector_air/air.rs index 9f84a2c50a..2a34b0887a 100644 --- a/stark-middleware/tests/fib_selector_air/air.rs +++ b/stark-middleware/tests/fib_selector_air/air.rs @@ -2,17 +2,29 @@ use std::borrow::Borrow; use super::columns::FibonacciSelectorCols; use crate::fib_air::columns::{FibonacciCols, NUM_FIBONACCI_COLS}; -use afs_middleware::interaction::Chip; -use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir, PairBuilder}; +use afs_middleware::interaction::{Chip, Interaction}; +use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir, PairBuilder, VirtualPairCol}; use p3_field::{AbstractField, Field}; use p3_matrix::{dense::RowMajorMatrix, Matrix}; pub struct FibonacciSelectorAir { pub sels: Vec, + pub enable_interactions: bool, } -// No interactions -impl Chip for FibonacciSelectorAir {} +impl Chip for FibonacciSelectorAir { + fn receives(&self) -> Vec> { + if self.enable_interactions { + vec![Interaction:: { + fields: vec![VirtualPairCol::::sum_main(vec![0, 1])], + count: VirtualPairCol::::single_preprocessed(0), + argument_index: 0, + }] + } else { + vec![] + } + } +} impl BaseAir for FibonacciSelectorAir { fn width(&self) -> usize { diff --git a/stark-middleware/tests/integration_test.rs b/stark-middleware/tests/integration_test.rs index 0cb14b53a7..c224036ba4 100644 --- a/stark-middleware/tests/integration_test.rs +++ b/stark-middleware/tests/integration_test.rs @@ -1,7 +1,14 @@ +#![feature(trait_upcasting)] +#![allow(incomplete_features)] + use afs_middleware::{ - prover::{trace::TraceCommitter, types::ProvenMultiMatrixAirTrace, PartitionProver}, + prover::{ + trace::TraceCommitter, + types::{ProvenMultiMatrixAirTrace, ProverRap}, + PartitionProver, + }, setup::PartitionSetup, - verifier::PartitionVerifier, + verifier::{types::VerifierRap, PartitionVerifier}, }; use p3_air::BaseAir; use p3_baby_bear::BabyBear; @@ -18,6 +25,10 @@ use crate::config::poseidon2::StarkConfigPoseidon2; mod config; mod fib_air; mod fib_selector_air; +mod interaction; + +trait ProverVerifierRap: ProverRap + VerifierRap {} +impl + VerifierRap> ProverVerifierRap for RAP {} #[test] fn test_single_fib_stark() { @@ -99,7 +110,10 @@ fn test_single_fib_selector_stark() { let sels: Vec = (0..n).map(|i| i % 2 == 0).collect(); let pis = [a, b, get_conditional_fib_number(&sels)].map(BabyBear::from_canonical_u32); - let air = FibonacciSelectorAir { sels }; + let air = FibonacciSelectorAir { + sels, + enable_interactions: false, + }; let prep_trace = air.preprocessed_trace(); let setup = PartitionSetup::new(&config); @@ -154,7 +168,10 @@ fn test_double_fib_starks() { let pis = [a, b, get_fib_number(n)].map(BabyBear::from_canonical_u32); let air1 = FibonacciAir {}; - let air2 = FibonacciSelectorAir { sels }; + let air2 = FibonacciSelectorAir { + sels, + enable_interactions: false, + }; let prep_trace1 = air1.preprocessed_trace(); let prep_trace2 = air2.preprocessed_trace(); diff --git a/stark-middleware/tests/interaction/dummy_interaction_air.rs b/stark-middleware/tests/interaction/dummy_interaction_air.rs new file mode 100644 index 0000000000..e097d57c6d --- /dev/null +++ b/stark-middleware/tests/interaction/dummy_interaction_air.rs @@ -0,0 +1,73 @@ +use afs_middleware::interaction::{Chip, Interaction}; +use afs_middleware_derive::AlignedBorrow; +use core::mem::size_of; +use p3_air::{Air, AirBuilderWithPublicValues, BaseAir, PairBuilder, VirtualPairCol}; +use p3_field::Field; +use p3_matrix::dense::RowMajorMatrix; +use p3_util::indices_arr; +use std::mem::transmute; + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct DummyInteractionCols { + pub count: F, + pub val: F, +} + +const NUM_DUMMY_INTERACTION_COLS: usize = size_of::>(); +const DUMMY_INTERACTION_COL_MAP: DummyInteractionCols = make_col_map(); + +const fn make_col_map() -> DummyInteractionCols { + let indices_arr = indices_arr::(); + unsafe { + transmute::<[usize; NUM_DUMMY_INTERACTION_COLS], DummyInteractionCols>(indices_arr) + } +} + +pub struct DummyInteractionAir { + // Send if true. Receive if false. + pub is_send: bool, +} + +impl Chip for DummyInteractionAir { + fn sends(&self) -> Vec> { + if self.is_send { + vec![Interaction:: { + fields: vec![VirtualPairCol::::single_main( + DUMMY_INTERACTION_COL_MAP.val, + )], + count: VirtualPairCol::::single_main(DUMMY_INTERACTION_COL_MAP.count), + argument_index: 0, + }] + } else { + vec![] + } + } + fn receives(&self) -> Vec> { + if !self.is_send { + vec![Interaction:: { + fields: vec![VirtualPairCol::::single_main( + DUMMY_INTERACTION_COL_MAP.val, + )], + count: VirtualPairCol::::single_main(DUMMY_INTERACTION_COL_MAP.count), + argument_index: 0, + }] + } else { + vec![] + } + } +} + +impl BaseAir for DummyInteractionAir { + fn width(&self) -> usize { + 2 + } + + fn preprocessed_trace(&self) -> Option> { + None + } +} + +impl Air for DummyInteractionAir { + fn eval(&self, _builder: &mut AB) {} +} diff --git a/stark-middleware/tests/interaction/mod.rs b/stark-middleware/tests/interaction/mod.rs new file mode 100644 index 0000000000..93a4a0336d --- /dev/null +++ b/stark-middleware/tests/interaction/mod.rs @@ -0,0 +1,330 @@ +use afs_middleware::{ + prover::{ + trace::TraceCommitter, + types::{ProvenMultiMatrixAirTrace, ProverRap}, + PartitionProver, + }, + setup::PartitionSetup, + verifier::{types::VerifierRap, PartitionVerifier, VerificationError}, +}; +use itertools::Itertools; +use p3_baby_bear::BabyBear; +use p3_field::AbstractField; +use p3_matrix::dense::RowMajorMatrix; +use p3_uni_stark::StarkGenericConfig; +use tracing_forest::util::LevelFilter; +use tracing_forest::ForestLayer; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::{EnvFilter, Registry}; + +use crate::{ + config::{self, poseidon2::StarkConfigPoseidon2}, + fib_selector_air::{air::FibonacciSelectorAir, trace::generate_trace_rows}, + get_conditional_fib_number, ProverVerifierRap, +}; + +mod dummy_interaction_air; + +type Val = BabyBear; + +fn verify_interactions( + traces: Vec>, + airs: Vec<&dyn ProverVerifierRap>, + pis: Vec, +) -> Result<(), VerificationError> { + // Set up tracing: + let env_filter = EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(); + let _ = Registry::default() + .with(env_filter) + .with(ForestLayer::default()) + .try_init(); + + let log_trace_degree = 3; + let perm = config::poseidon2::random_perm(); + let config = config::poseidon2::default_config(&perm, log_trace_degree); + + let setup = PartitionSetup::new(&config); + let (pk, vk) = setup.setup(airs.iter().map(|air| air.preprocessed_trace()).collect()); + + let trace_committer = TraceCommitter::::new(config.pcs()); + let proven_trace = trace_committer.commit(traces); + + let proven = ProvenMultiMatrixAirTrace { + trace_data: &proven_trace, + airs: airs + .iter() + .map(|&air| air as &dyn ProverRap) + .collect(), + }; + + let prover = PartitionProver::new(config); + let mut challenger = config::poseidon2::Challenger::new(perm.clone()); + let proof = prover.prove(&mut challenger, &pk, vec![proven], &pis); + + // Verify the proof: + // Start from clean challenger + let mut challenger = config::poseidon2::Challenger::new(perm.clone()); + let verifier = PartitionVerifier::new(prover.config); + verifier.verify( + &mut challenger, + vk, + airs.iter() + .map(|&air| air as &dyn VerifierRap) + .collect(), + proof, + &pis, + ) +} + +#[test] +fn test_interaction_fib_selector_happy_path() { + let log_trace_degree = 3; + + // Public inputs: + let a = 0u32; + let b = 1u32; + let n = 1usize << log_trace_degree; + + let sels: Vec = (0..n).map(|i| i % 2 == 0).collect(); + let fib_res = get_conditional_fib_number(&sels); + let pis = vec![a, b, fib_res] + .into_iter() + .map(Val::from_canonical_u32) + .collect_vec(); + + let air = FibonacciSelectorAir { + sels: sels.clone(), + enable_interactions: true, + }; + let trace = generate_trace_rows::(a, b, &sels); + + let mut curr_a = a; + let mut curr_b = b; + let mut vals = vec![]; + for sel in sels { + vals.push(Val::from_bool(sel)); + if sel { + let c = curr_a + curr_b; + curr_a = curr_b; + curr_b = c; + } + vals.push(Val::from_canonical_u32(curr_b)); + } + let sender_trace = RowMajorMatrix::new(vals, 2); + let sender_air = dummy_interaction_air::DummyInteractionAir { is_send: true }; + verify_interactions(vec![trace, sender_trace], vec![&air, &sender_air], pis) + .expect("Verification failed"); +} + +fn to_field_vec(v: Vec) -> Vec { + v.into_iter().map(Val::from_canonical_u32).collect() +} + +#[test] +fn test_interaction_stark_multi_rows_happy_path() { + // Mul Val + // 0 1 + // 7 4 + // 3 5 + // 546 889 + let sender_trace = RowMajorMatrix::new(to_field_vec(vec![0, 1, 3, 5, 7, 4, 546, 889]), 2); + let sender_air = dummy_interaction_air::DummyInteractionAir { is_send: true }; + + // Mul Val + // 1 5 + // 3 4 + // 4 4 + // 2 5 + // 0 123 + // 545 889 + // 1 889 + // 0 456 + let receiver_trace = RowMajorMatrix::new( + to_field_vec(vec![ + 1, 5, 3, 4, 4, 4, 2, 5, 0, 123, 545, 889, 1, 889, 0, 456, + ]), + 2, + ); + let receiver_air = dummy_interaction_air::DummyInteractionAir { is_send: false }; + verify_interactions( + vec![sender_trace, receiver_trace], + vec![&sender_air, &receiver_air], + vec![], + ) + .expect("Verification failed"); +} + +#[test] +fn test_interaction_stark_multi_rows_neg() { + // Mul Val + // 0 1 + // 3 5 + // 7 4 + // 546 0 + let sender_trace = RowMajorMatrix::new(to_field_vec(vec![0, 1, 3, 5, 7, 4, 546, 0]), 2); + let sender_air = dummy_interaction_air::DummyInteractionAir { is_send: true }; + + // count of 0 is 545 != 546 in send. + // Mul Val + // 1 5 + // 3 4 + // 4 4 + // 2 5 + // 0 123 + // 545 0 + // 0 0 + // 0 456 + let receiver_trace = RowMajorMatrix::new( + to_field_vec(vec![1, 5, 3, 4, 4, 4, 2, 5, 0, 123, 545, 0, 0, 0, 0, 456]), + 2, + ); + let receiver_air = dummy_interaction_air::DummyInteractionAir { is_send: false }; + let res = verify_interactions( + vec![sender_trace, receiver_trace], + vec![&sender_air, &receiver_air], + vec![], + ); + assert_eq!(res, Err(VerificationError::NonZeroCumulativeSum)); +} + +#[test] +fn test_interaction_stark_all_0_sender_happy_path() { + // Mul Val + // 0 1 + // 0 646 + // 0 0 + // 0 589 + let sender_trace = RowMajorMatrix::new(to_field_vec(vec![0, 1, 0, 5, 0, 4, 0, 889]), 2); + let sender_air = dummy_interaction_air::DummyInteractionAir { is_send: true }; + verify_interactions(vec![sender_trace], vec![&sender_air], vec![]) + .expect("Verification failed"); +} + +#[test] +fn test_interaction_stark_multi_senders_happy_path() { + // Mul Val + // 0 1 + // 6 4 + // 3 5 + // 333 889 + let sender_trace1 = RowMajorMatrix::new(to_field_vec(vec![0, 1, 3, 5, 6, 4, 333, 889]), 2); + // Mul Val + // 1 4 + // 213 889 + let sender_trace2 = RowMajorMatrix::new(to_field_vec(vec![1, 4, 213, 889]), 2); + + let sender_air = dummy_interaction_air::DummyInteractionAir { is_send: true }; + + // Mul Val + // 1 5 + // 3 4 + // 4 4 + // 2 5 + // 0 123 + // 545 889 + // 1 889 + // 0 456 + let receiver_trace = RowMajorMatrix::new( + to_field_vec(vec![ + 1, 5, 3, 4, 4, 4, 2, 5, 0, 123, 545, 889, 1, 889, 0, 456, + ]), + 2, + ); + let receiver_air = dummy_interaction_air::DummyInteractionAir { is_send: false }; + verify_interactions( + vec![sender_trace1, sender_trace2, receiver_trace], + vec![&sender_air, &sender_air, &receiver_air], + vec![], + ) + .expect("Verification failed"); +} + +#[test] +fn test_interaction_stark_multi_senders_neg() { + // Mul Val + // 0 1 + // 5 4 + // 3 5 + // 333 889 + let sender_trace1 = RowMajorMatrix::new(to_field_vec(vec![0, 1, 3, 5, 5, 4, 333, 889]), 2); + // Mul Val + // 1 4 + // 213 889 + let sender_trace2 = RowMajorMatrix::new(to_field_vec(vec![1, 4, 213, 889]), 2); + + let sender_air = dummy_interaction_air::DummyInteractionAir { is_send: true }; + + // Mul Val + // 1 5 + // 3 4 + // 4 4 + // 2 5 + // 0 123 + // 545 889 + // 1 889 + // 0 456 + let receiver_trace = RowMajorMatrix::new( + to_field_vec(vec![ + 1, 5, 3, 4, 4, 4, 2, 5, 0, 123, 545, 889, 1, 889, 0, 456, + ]), + 2, + ); + let receiver_air = dummy_interaction_air::DummyInteractionAir { is_send: false }; + let res = verify_interactions( + vec![sender_trace1, sender_trace2, receiver_trace], + vec![&sender_air, &sender_air, &receiver_air], + vec![], + ); + assert_eq!(res, Err(VerificationError::NonZeroCumulativeSum)); +} + +#[test] +fn test_interaction_stark_multi_sender_receiver_happy_path() { + // Mul Val + // 0 1 + // 6 4 + // 3 5 + // 333 889 + let sender_trace1 = RowMajorMatrix::new(to_field_vec(vec![0, 1, 3, 5, 6, 4, 333, 889]), 2); + // Mul Val + // 1 4 + // 213 889 + let sender_trace2 = RowMajorMatrix::new(to_field_vec(vec![1, 4, 213, 889]), 2); + + let sender_air = dummy_interaction_air::DummyInteractionAir { is_send: true }; + + // Mul Val + // 1 5 + // 3 4 + // 4 4 + // 2 5 + // 0 123 + // 545 889 + // 0 289 + // 0 456 + let receiver_trace1 = RowMajorMatrix::new( + to_field_vec(vec![ + 1, 5, 3, 4, 4, 4, 2, 5, 0, 123, 545, 889, 0, 289, 0, 456, + ]), + 2, + ); + + // Mul Val + // 1 889 + let receiver_trace2 = RowMajorMatrix::new(to_field_vec(vec![1, 889]), 2); + let receiver_air = dummy_interaction_air::DummyInteractionAir { is_send: false }; + verify_interactions( + vec![ + sender_trace1, + sender_trace2, + receiver_trace1, + receiver_trace2, + ], + vec![&sender_air, &sender_air, &receiver_air, &receiver_air], + vec![], + ) + .expect("Verification failed"); +}