diff --git a/Cargo.toml b/Cargo.toml index 86831ef3..a8b0d2bf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,6 @@ members = [ "keccak", "keccak-air", "lde", - "ldt", "matrix", "merkle-tree", "maybe-rayon", diff --git a/README.md b/README.md index b5aaa24a..94156c2d 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,14 @@ Hashes - [x] Monolith +## Benchmark + +We sometimes use a Keccak AIR to compare Plonky3's performance to other libraries like Plonky2. Several variations are possible here, with different fields and so forth, but here is one example: +``` +RUST_LOG=info cargo run --example prove_baby_bear_keccak --release --features parallel +``` + + ## License Licensed under either of diff --git a/air/src/air.rs b/air/src/air.rs index d5b075ed..db8446d0 100644 --- a/air/src/air.rs +++ b/air/src/air.rs @@ -22,7 +22,7 @@ pub trait Air: BaseAir { pub trait AirBuilder: Sized { type F: Field; - type Expr: AbstractField + type Expr: AbstractField + From + Add + Add @@ -106,51 +106,40 @@ pub trait AirBuilder: Sized { let x = x.into(); self.assert_zero(x.clone() * (x - Self::Expr::one())); } - - fn assert_zero_ext(&mut self, x: I) - where - ExprExt: AbstractExtensionField, - I: Into, - { - for xb in x.into().as_base_slice().iter().cloned() { - self.assert_zero(xb); - } - } - - fn assert_eq_ext(&mut self, x: I1, y: I2) - where - ExprExt: AbstractExtensionField, - I1: Into, - I2: Into, - { - self.assert_zero_ext::(x.into() - y.into()); - } - - fn assert_one_ext(&mut self, x: I) - where - ExprExt: AbstractExtensionField, - I: Into, - { - let xe: ExprExt = x.into(); - let parts = xe.as_base_slice(); - self.assert_one(parts[0].clone()); - for part in &parts[1..] { - self.assert_zero(part.clone()); - } - } } pub trait PairBuilder: AirBuilder { fn preprocessed(&self) -> Self::M; } -pub trait PermutationAirBuilder: AirBuilder { +pub trait ExtensionBuilder: AirBuilder { type EF: ExtensionField; type ExprEF: AbstractExtensionField; type VarEF: Into + Copy; + fn assert_zero_ext(&mut self, x: I) + where + I: Into; + + fn assert_eq_ext(&mut self, x: I1, y: I2) + where + I1: Into, + I2: Into, + { + self.assert_zero_ext(x.into() - y.into()); + } + + fn assert_one_ext(&mut self, x: I) + where + I: Into, + { + self.assert_eq_ext(x, Self::ExprEF::one()) + } +} + +pub trait PermutationAirBuilder: ExtensionBuilder { type MP: MatrixRowSlices; fn permutation(&self) -> Self::MP; @@ -165,21 +154,6 @@ pub struct FilteredAirBuilder<'a, AB: AirBuilder> { condition: AB::Expr, } -impl<'a, AB: PermutationAirBuilder> PermutationAirBuilder for FilteredAirBuilder<'a, AB> { - type EF = AB::EF; - type VarEF = AB::VarEF; - type ExprEF = AB::ExprEF; - type MP = AB::MP; - - fn permutation(&self) -> Self::MP { - self.inner.permutation() - } - - fn permutation_randomness(&self) -> &[Self::EF] { - self.inner.permutation_randomness() - } -} - impl<'a, AB: AirBuilder> AirBuilder for FilteredAirBuilder<'a, AB> { type F = AB::F; type Expr = AB::Expr; @@ -207,6 +181,32 @@ impl<'a, AB: AirBuilder> AirBuilder for FilteredAirBuilder<'a, AB> { } } +impl<'a, AB: ExtensionBuilder> ExtensionBuilder for FilteredAirBuilder<'a, AB> { + type EF = AB::EF; + type VarEF = AB::VarEF; + type ExprEF = AB::ExprEF; + + fn assert_zero_ext(&mut self, x: I) + where + I: Into, + { + self.inner + .assert_zero_ext(x.into() * self.condition.clone()); + } +} + +impl<'a, AB: PermutationAirBuilder> PermutationAirBuilder for FilteredAirBuilder<'a, AB> { + type MP = AB::MP; + + fn permutation(&self) -> Self::MP { + self.inner.permutation() + } + + fn 
permutation_randomness(&self) -> &[Self::EF] { + self.inner.permutation_randomness() + } +} + #[cfg(test)] mod tests { use p3_matrix::MatrixRowSlices; diff --git a/baby-bear/src/baby_bear.rs b/baby-bear/src/baby_bear.rs index 6df16d33..9b9e71ed 100644 --- a/baby-bear/src/baby_bear.rs +++ b/baby-bear/src/baby_bear.rs @@ -263,10 +263,38 @@ impl TwoAdicField for BabyBear { const TWO_ADICITY: usize = 27; fn two_adic_generator(bits: usize) -> Self { - // TODO: Consider a `match` which may speed this up. assert!(bits <= Self::TWO_ADICITY); - let base = Self::from_canonical_u32(0x1a427a41); // generates the whole 2^TWO_ADICITY group - base.exp_power_of_2(Self::TWO_ADICITY - bits) + match bits { + 0 => Self::one(), + 1 => Self::from_canonical_u32(0x78000000), + 2 => Self::from_canonical_u32(0x67055c21), + 3 => Self::from_canonical_u32(0x5ee99486), + 4 => Self::from_canonical_u32(0xbb4c4e4), + 5 => Self::from_canonical_u32(0x2d4cc4da), + 6 => Self::from_canonical_u32(0x669d6090), + 7 => Self::from_canonical_u32(0x17b56c64), + 8 => Self::from_canonical_u32(0x67456167), + 9 => Self::from_canonical_u32(0x688442f9), + 10 => Self::from_canonical_u32(0x145e952d), + 11 => Self::from_canonical_u32(0x4fe61226), + 12 => Self::from_canonical_u32(0x4c734715), + 13 => Self::from_canonical_u32(0x11c33e2a), + 14 => Self::from_canonical_u32(0x62c3d2b1), + 15 => Self::from_canonical_u32(0x77cad399), + 16 => Self::from_canonical_u32(0x54c131f4), + 17 => Self::from_canonical_u32(0x4cabd6a6), + 18 => Self::from_canonical_u32(0x5cf5713f), + 19 => Self::from_canonical_u32(0x3e9430e8), + 20 => Self::from_canonical_u32(0xba067a3), + 21 => Self::from_canonical_u32(0x18adc27d), + 22 => Self::from_canonical_u32(0x21fd55bc), + 23 => Self::from_canonical_u32(0x4b859b3d), + 24 => Self::from_canonical_u32(0x3bd57996), + 25 => Self::from_canonical_u32(0x4483d85a), + 26 => Self::from_canonical_u32(0x3a26eef8), + 27 => Self::from_canonical_u32(0x1a427a41), + _ => unreachable!("Already asserted that bits <= Self::TWO_ADICITY"), + } } } @@ -402,6 +430,17 @@ mod tests { type F = BabyBear; + #[test] + fn test_baby_bear_two_adicity_generators() { + let base = BabyBear::from_canonical_u32(0x1a427a41); + for bits in 0..=BabyBear::TWO_ADICITY { + assert_eq!( + BabyBear::two_adic_generator(bits), + base.exp_power_of_2(BabyBear::TWO_ADICITY - bits) + ); + } + } + #[test] fn test_baby_bear() { let f = F::from_canonical_u32(100); diff --git a/challenger/src/duplex_challenger.rs b/challenger/src/duplex_challenger.rs index 8caf3a08..528afaa9 100644 --- a/challenger/src/duplex_challenger.rs +++ b/challenger/src/duplex_challenger.rs @@ -1,7 +1,7 @@ use alloc::vec; use alloc::vec::Vec; -use p3_field::PrimeField64; +use p3_field::{ExtensionField, Field, PrimeField64}; use p3_symmetric::CryptographicPermutation; use crate::{CanObserve, CanSample, CanSampleBits, FieldChallenger}; @@ -87,21 +87,24 @@ where } } -impl CanSample for DuplexChallenger +impl CanSample for DuplexChallenger where - F: Copy, + F: Field, + EF: ExtensionField, P: CryptographicPermutation<[F; WIDTH]>, { - fn sample(&mut self) -> F { - // If we have buffered inputs, we must perform a duplexing so that the challenge will - // reflect them. Or if we've run out of outputs, we must perform a duplexing to get more. 
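Note on the `BabyBear::two_adic_generator` table above: the match replaces a runtime `exp_power_of_2` with precomputed constants, and each arm for `bits` is the full two-adic generator `0x1a427a41` raised to the power `2^(27 - bits)`, which is exactly what the new `test_baby_bear_two_adicity_generators` test checks. A minimal standalone sketch of that invariant, assuming only the `p3-baby-bear` and `p3-field` APIs used in this diff:

```rust
use p3_baby_bear::BabyBear;
use p3_field::{AbstractField, TwoAdicField};

fn main() {
    // g generates the full order-2^27 subgroup; squaring it (27 - bits) times
    // yields a generator of the order-2^bits subgroup, which is what each
    // match arm in `two_adic_generator` stores.
    let g = BabyBear::from_canonical_u32(0x1a427a41);
    for bits in 0..=BabyBear::TWO_ADICITY {
        assert_eq!(
            BabyBear::two_adic_generator(bits),
            g.exp_power_of_2(BabyBear::TWO_ADICITY - bits)
        );
    }
}
```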
- if !self.input_buffer.is_empty() || self.output_buffer.is_empty() { - self.duplexing(); - } - - self.output_buffer - .pop() - .expect("Output buffer should be non-empty") + fn sample(&mut self) -> EF { + EF::from_base_fn(|_| { + // If we have buffered inputs, we must perform a duplexing so that the challenge will + // reflect them. Or if we've run out of outputs, we must perform a duplexing to get more. + if !self.input_buffer.is_empty() || self.output_buffer.is_empty() { + self.duplexing(); + } + + self.output_buffer + .pop() + .expect("Output buffer should be non-empty") + }) } } @@ -218,7 +221,9 @@ mod tests { (0..WIDTH / 2).for_each(|element| { assert_eq!( - duplex_challenger.sample(), + as CanSample>::sample( + &mut duplex_challenger + ), F::from_canonical_u8(element as u8) ); assert_eq!( @@ -229,7 +234,12 @@ mod tests { }); (0..WIDTH / 2).for_each(|i| { - assert_eq!(duplex_challenger.sample(), F::from_canonical_u8(0)); + assert_eq!( + as CanSample>::sample( + &mut duplex_challenger + ), + F::from_canonical_u8(0) + ); assert_eq!(duplex_challenger.input_buffer, vec![]); assert_eq!( duplex_challenger.output_buffer, diff --git a/challenger/src/grinding_challenger.rs b/challenger/src/grinding_challenger.rs index 4b3e6485..d08c24c2 100644 --- a/challenger/src/grinding_challenger.rs +++ b/challenger/src/grinding_challenger.rs @@ -1,34 +1,39 @@ -use p3_field::PrimeField64; +use p3_field::{Field, PrimeField64}; use p3_maybe_rayon::prelude::*; use p3_symmetric::CryptographicPermutation; use tracing::instrument; -use crate::{DuplexChallenger, FieldChallenger}; +use crate::{CanObserve, CanSampleBits, DuplexChallenger}; -pub trait GrindingChallenger: FieldChallenger + Clone { - // Can be overridden for more efficient methods not involving cloning, depending on the - // internals of the challenger. 
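Note on the `DuplexChallenger` change above: `CanSample` is now implemented for every `EF: ExtensionField<F>`, with each of the `EF::D` base coefficients drawn from the sponge via `EF::from_base_fn` (duplexing whenever the output buffer runs dry). A side effect, visible in the updated tests, is that a bare `challenger.sample()` can become ambiguous once several `CanSample` impls apply, so calls are fully qualified. A sketch with a hypothetical helper `sample_base_and_ext`:

```rust
use p3_challenger::CanSample;

// With CanSample<EF> implemented for every extension EF of the base field,
// method resolution cannot pick between the impls, so we spell them out.
fn sample_base_and_ext<F, EF, C>(challenger: &mut C) -> (F, EF)
where
    C: CanSample<F> + CanSample<EF>,
{
    let x = <C as CanSample<F>>::sample(challenger); // one base element
    let y = <C as CanSample<EF>>::sample(challenger); // for the duplex challenger: EF::D base draws
    (x, y)
}
```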
- #[instrument(name = "grind for proof-of-work witness", skip_all)] - fn grind(&mut self, bits: usize) -> F { - let witness = (0..F::ORDER_U64) - .into_par_iter() - .map(|i| F::from_canonical_u64(i)) - .find_any(|witness| self.clone().check_witness(bits, *witness)) - .expect("failed to find witness"); - assert!(self.check_witness(bits, witness)); - witness - } +pub trait GrindingChallenger: + CanObserve + CanSampleBits + Sync + Clone +{ + type Witness: Field; + + fn grind(&mut self, bits: usize) -> Self::Witness; #[must_use] - fn check_witness(&mut self, bits: usize, witness: F) -> bool { + fn check_witness(&mut self, bits: usize, witness: Self::Witness) -> bool { self.observe(witness); self.sample_bits(bits) == 0 } } -impl GrindingChallenger for DuplexChallenger +impl GrindingChallenger for DuplexChallenger where F: PrimeField64, P: CryptographicPermutation<[F; WIDTH]>, { + type Witness = F; + + #[instrument(name = "grind for proof-of-work witness", skip_all)] + fn grind(&mut self, bits: usize) -> Self::Witness { + let witness = (0..F::ORDER_U64) + .into_par_iter() + .map(|i| F::from_canonical_u64(i)) + .find_any(|witness| self.clone().check_witness(bits, *witness)) + .expect("failed to find witness"); + assert!(self.check_witness(bits, witness)); + witness + } } diff --git a/commit/src/adapters/extension_mmcs.rs b/commit/src/adapters/extension_mmcs.rs index 03eea6a8..9377586a 100644 --- a/commit/src/adapters/extension_mmcs.rs +++ b/commit/src/adapters/extension_mmcs.rs @@ -115,7 +115,9 @@ where InnerMat: Matrix, { fn width(&self) -> usize { - self.inner.width() * >::D + let d = >::D; + debug_assert!(self.inner.width() % d == 0); + self.inner.width() / d } fn height(&self) -> usize { diff --git a/commit/src/pcs.rs b/commit/src/pcs.rs index d0dd45b6..8a56d50c 100644 --- a/commit/src/pcs.rs +++ b/commit/src/pcs.rs @@ -2,6 +2,7 @@ use alloc::vec; use alloc::vec::Vec; +use core::fmt::Debug; use p3_challenger::FieldChallenger; use p3_field::{ExtensionField, Field}; @@ -25,7 +26,7 @@ pub trait Pcs> { /// The opening argument. type Proof: Serialize + DeserializeOwned; - type Error; + type Error: Debug; fn commit_batches(&self, polynomials: Vec) -> (Self::Commitment, Self::ProverData); diff --git a/field/src/extension/binomial_extension.rs b/field/src/extension/binomial_extension.rs index f8c4251d..39551332 100644 --- a/field/src/extension/binomial_extension.rs +++ b/field/src/extension/binomial_extension.rs @@ -13,13 +13,8 @@ use serde::{Deserialize, Serialize}; use super::{HasFrobenius, HasTwoAdicBionmialExtension}; use crate::extension::BinomiallyExtendable; use crate::field::Field; -use crate::restriction::Res; use crate::{field_to_array, AbstractExtensionField, AbstractField, ExtensionField, TwoAdicField}; -/// The algebra F_D[X] / (X^D - W) where F_D is the binomial extension field over `F`. 
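Note on the `extension_mmcs.rs` hunk above: that is a genuine bug fix, not a refactor. An extension-field matrix is backed by a base-field matrix with `D` base columns per extension column, so the wrapper's `width()` must divide by `D` rather than multiply. A worked example with illustrative numbers:

```rust
fn main() {
    // An inner base-field matrix backing 2 extension columns over a
    // degree-4 extension (D = 4) has 2 * 4 = 8 base columns, so the
    // extension-view width is 8 / 4 = 2; the old `width() * D` would
    // have reported 32.
    let d = 4usize; // <EF as AbstractExtensionField<F>>::D
    let inner_width = 8usize;
    assert_eq!(inner_width % d, 0);
    assert_eq!(inner_width / d, 2);
}
```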
-pub type BinomialExtensionAlgebra = - BinomialExtensionField>, D>; - #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, Serialize, Deserialize)] pub struct BinomialExtensionField { #[serde( diff --git a/field/src/field.rs b/field/src/field.rs index 69ce1b6a..32e5b248 100644 --- a/field/src/field.rs +++ b/field/src/field.rs @@ -242,6 +242,7 @@ pub trait PrimeField32: PrimeField64 { pub trait AbstractExtensionField: AbstractField + + From + Add + AddAssign + Sub diff --git a/field/src/lib.rs b/field/src/lib.rs index 5769e1db..8f483393 100644 --- a/field/src/lib.rs +++ b/field/src/lib.rs @@ -11,7 +11,6 @@ pub mod extension; mod field; mod helpers; mod packed; -mod restriction; pub use array::*; pub use batch_inverse::*; @@ -19,4 +18,3 @@ pub use exponentiation::*; pub use field::*; pub use helpers::*; pub use packed::*; -pub use restriction::*; diff --git a/field/src/restriction.rs b/field/src/restriction.rs deleted file mode 100644 index f1336f21..00000000 --- a/field/src/restriction.rs +++ /dev/null @@ -1,217 +0,0 @@ -use core::marker::PhantomData; -use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign}; - -use crate::{AbstractExtensionField, AbstractField, ExtensionField, Field}; - -/// The restriction of scalars from a field `EF` to a subfield `F`. -#[derive(Clone, Copy, Debug, Eq, PartialEq, Default, Hash)] -pub struct Res(EF, PhantomData); - -impl> Res { - /// Returns the underlying field element. - pub fn into_inner(self) -> EF { - self.0 - } - - pub const fn from_inner(e: EF) -> Self { - Self(e, PhantomData) - } -} - -impl> AbstractField for Res { - type F = F; - - fn from_f(f: F) -> Self { - Self::from_inner(EF::from_base(f)) - } - - fn zero() -> Self { - Self::from_inner(EF::zero()) - } - - fn one() -> Self { - Self::from_inner(EF::one()) - } - - fn two() -> Self { - Self::from_inner(EF::two()) - } - - fn from_bool(b: bool) -> Self { - Self::from_inner(EF::from_bool(b)) - } - - fn from_canonical_u8(n: u8) -> Self { - Self::from_inner(EF::from_canonical_u8(n)) - } - - fn from_canonical_u16(n: u16) -> Self { - Self::from_inner(EF::from_canonical_u16(n)) - } - - fn from_canonical_u32(n: u32) -> Self { - Self::from_inner(EF::from_canonical_u32(n)) - } - - fn from_canonical_u64(n: u64) -> Self { - Self::from_inner(EF::from_canonical_u64(n)) - } - - fn from_canonical_usize(n: usize) -> Self { - Self::from_inner(EF::from_canonical_usize(n)) - } - - fn from_wrapped_u32(n: u32) -> Self { - Self::from_inner(EF::from_wrapped_u32(n)) - } - - fn from_wrapped_u64(n: u64) -> Self { - Self::from_inner(EF::from_wrapped_u64(n)) - } - - fn neg_one() -> Self { - Self::from_inner(EF::neg_one()) - } - - fn generator() -> Self { - Self::from_inner(EF::generator()) - } -} - -impl> AbstractExtensionField for Res { - const D: usize = EF::D; - - fn from_base(b: F) -> Self { - Self::from_inner(EF::from_base(b)) - } - - fn from_base_fn F>(f: Fun) -> Self { - Self::from_inner(EF::from_base_fn(f)) - } - - fn from_base_slice(bs: &[F]) -> Self { - Self::from_inner(EF::from_base_slice(bs)) - } - - fn as_base_slice(&self) -> &[F] { - self.0.as_base_slice() - } -} - -impl> From for Res { - fn from(f: F) -> Self { - Res(EF::from_base(f), PhantomData) - } -} - -impl> Add for Res { - type Output = Self; - - fn add(self, other: Self) -> Self { - Self::from_inner(self.0 + other.0) - } -} - -impl> AddAssign for Res { - fn add_assign(&mut self, other: Self) { - self.0.add_assign(other.0); - } -} - -impl> Mul for Res { - type Output = Self; - - fn mul(self, other: Self) -> Self { - Self::from_inner(self.0 * 
other.0) - } -} - -impl> MulAssign for Res { - fn mul_assign(&mut self, other: Self) { - self.0.mul_assign(other.0); - } -} - -impl> MulAssign for Res { - fn mul_assign(&mut self, other: F) { - self.0.mul_assign(other); - } -} - -impl> Sub for Res { - type Output = Self; - - fn sub(self, other: Self) -> Self { - Self::from_inner(self.0 - other.0) - } -} - -impl> Neg for Res { - type Output = Self; - - fn neg(self) -> Self { - Self::from_inner(-self.0) - } -} - -impl> SubAssign for Res { - fn sub_assign(&mut self, other: Self) { - self.0.sub_assign(other.0); - } -} - -impl> Add for Res { - type Output = Self; - - fn add(self, other: F) -> Self { - Self::from_inner(self.0 + other) - } -} - -impl> AddAssign for Res { - fn add_assign(&mut self, other: F) { - self.0.add_assign(other); - } -} - -impl> Sub for Res { - type Output = Self; - - fn sub(self, other: F) -> Self { - Self::from_inner(self.0 - other) - } -} - -impl> SubAssign for Res { - fn sub_assign(&mut self, other: F) { - self.0.sub_assign(other); - } -} - -impl> Mul for Res { - type Output = Self; - - fn mul(self, other: F) -> Self { - Self::from_inner(self.0 * other) - } -} - -impl> Div for Res { - type Output = Self; - - fn div(self, other: Self) -> Self { - Self::from_inner(self.0 / other.0) - } -} - -impl> core::iter::Product for Res { - fn product>(iter: I) -> Self { - iter.fold(Self::one(), |acc, x| acc * x) - } -} - -impl> core::iter::Sum for Res { - fn sum>(iter: I) -> Self { - iter.fold(Self::zero(), |acc, x| acc + x) - } -} diff --git a/fri/Cargo.toml b/fri/Cargo.toml index 91dbc123..2694e896 100644 --- a/fri/Cargo.toml +++ b/fri/Cargo.toml @@ -7,8 +7,9 @@ license = "MIT OR Apache-2.0" [dependencies] p3-challenger = { path = "../challenger" } p3-commit = { path = "../commit" } +p3-dft = { path = "../dft" } p3-field = { path = "../field" } -p3-ldt = { path = "../ldt" } +p3-interpolation = { path = "../interpolation" } p3-matrix = { path = "../matrix" } p3-maybe-rayon = { path = "../maybe-rayon" } p3-util = { path = "../util" } diff --git a/fri/src/config.rs b/fri/src/config.rs index 618abf0e..d0ece0bb 100644 --- a/fri/src/config.rs +++ b/fri/src/config.rs @@ -1,88 +1,12 @@ -use core::marker::PhantomData; - -use p3_challenger::{CanObserve, GrindingChallenger}; -use p3_commit::{DirectMmcs, Mmcs}; -use p3_field::{ExtensionField, PrimeField64, TwoAdicField}; - -pub trait FriConfig { - type Val: PrimeField64; - type Challenge: ExtensionField + TwoAdicField; - - type InputMmcs: Mmcs; - type CommitPhaseMmcs: DirectMmcs; - - type Challenger: GrindingChallenger - + CanObserve<>::Commitment>; - - fn commit_phase_mmcs(&self) -> &Self::CommitPhaseMmcs; - - fn num_queries(&self) -> usize; - - fn log_blowup(&self) -> usize; - - fn blowup(&self) -> usize { - 1 << self.log_blowup() - } - - fn proof_of_work_bits(&self) -> usize; -} - -pub struct FriConfigImpl { - log_blowup: usize, - num_queries: usize, - proof_of_work_bits: usize, - commit_phase_mmcs: CommitPhaseMmcs, - _phantom: PhantomData<(Val, Challenge, InputMmcs, Challenger)>, -} - -impl - FriConfigImpl -{ - pub fn new( - log_blowup: usize, - num_queries: usize, - proof_of_work_bits: usize, - commit_phase_mmcs: CommitPhaseMmcs, - ) -> Self { - Self { - log_blowup, - num_queries, - commit_phase_mmcs, - proof_of_work_bits, - _phantom: PhantomData, - } - } +pub struct FriConfig { + pub log_blowup: usize, + pub num_queries: usize, + pub proof_of_work_bits: usize, + pub mmcs: M, } -impl FriConfig - for FriConfigImpl -where - Val: PrimeField64, - Challenge: ExtensionField + TwoAdicField, - 
InputMmcs: Mmcs, - CommitPhaseMmcs: DirectMmcs, - Challenger: - GrindingChallenger + CanObserve<>::Commitment>, -{ - type Val = Val; - type Challenge = Challenge; - type InputMmcs = InputMmcs; - type CommitPhaseMmcs = CommitPhaseMmcs; - type Challenger = Challenger; - - fn commit_phase_mmcs(&self) -> &CommitPhaseMmcs { - &self.commit_phase_mmcs - } - - fn num_queries(&self) -> usize { - self.num_queries - } - - fn log_blowup(&self) -> usize { - self.log_blowup - } - - fn proof_of_work_bits(&self) -> usize { - self.proof_of_work_bits +impl FriConfig { + pub fn blowup(&self) -> usize { + 1 << self.log_blowup } } diff --git a/fri/src/lib.rs b/fri/src/lib.rs index 2ede7fa2..14b7dfcb 100644 --- a/fri/src/lib.rs +++ b/fri/src/lib.rs @@ -4,72 +4,14 @@ extern crate alloc; -use alloc::vec::Vec; - -use p3_commit::Mmcs; -use p3_ldt::{Ldt, LdtBasedPcs}; -use p3_matrix::Dimensions; -use verifier::VerificationErrorForFriConfig; - -use crate::prover::prove; -use crate::verifier::verify; - mod config; mod fold_even_odd; -mod matrix_reducer; mod proof; -mod prover; -mod verifier; +pub mod prover; +pub mod two_adic_pcs; +pub mod verifier; pub use config::*; pub use fold_even_odd::*; pub use proof::*; - -pub struct FriLdt { - pub config: FC, -} - -impl Ldt for FriLdt { - type Proof = FriProof; - type Error = VerificationErrorForFriConfig; - - fn log_blowup(&self) -> usize { - self.config.log_blowup() - } - - fn prove( - &self, - input_mmcs: &[FC::InputMmcs], - input_data: &[&>::ProverData], - challenger: &mut FC::Challenger, - ) -> Self::Proof { - prove::(&self.config, input_mmcs, input_data, challenger) - } - - fn verify( - &self, - input_mmcs: &[FC::InputMmcs], - input_dims: &[Vec], - input_commits: &[>::Commitment], - proof: &Self::Proof, - challenger: &mut FC::Challenger, - ) -> Result<(), Self::Error> { - verify::( - &self.config, - input_mmcs, - input_dims, - input_commits, - proof, - challenger, - ) - } -} - -pub type FriBasedPcs = LdtBasedPcs< - ::Val, - ::Challenge, - Dft, - Mmcs, - FriLdt, - Challenger, ->; +pub use two_adic_pcs::*; diff --git a/fri/src/matrix_reducer.rs b/fri/src/matrix_reducer.rs deleted file mode 100644 index 0d4cbd9a..00000000 --- a/fri/src/matrix_reducer.rs +++ /dev/null @@ -1,176 +0,0 @@ -use alloc::vec; -use alloc::vec::Vec; - -use itertools::Itertools; -use p3_field::{ExtensionField, Field, PackedField}; -use p3_matrix::MatrixRows; -use p3_maybe_rayon::prelude::*; -use p3_util::indices_arr; -use tracing::instrument; - -// seems to be the sweet spot, could be tweaked based on benches -const BATCH_SIZE: usize = 8; - -/// Optimized matrix reducer. Only works for binomial extension fields, or extension fields -/// in which multiplication and addition with a base field element are simple elementwise operations. 
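Note on the `fri/src/config.rs` rewrite above: the `FriConfig` trait plus `FriConfigImpl` machinery collapses into a plain struct with public fields, so configuration is now a struct literal and `blowup()` is derived from `log_blowup`. A usage sketch mirroring the constructions later in this diff (the hypothetical `make_config` helper takes whatever commit-phase MMCS the caller already built, e.g. an `ExtensionMmcs` over a Merkle-tree MMCS):

```rust
use p3_fri::FriConfig;

fn make_config<M>(mmcs: M) -> FriConfig<M> {
    FriConfig {
        log_blowup: 1, // rate 1/2; cfg.blowup() == 1 << 1 == 2
        num_queries: 10,
        proof_of_work_bits: 8,
        mmcs,
    }
}
```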
-pub(crate) struct MatrixReducer { - alpha: EF, - alpha_pow_width: EF, - transposed_alphas: Vec<[F::Packing; BATCH_SIZE]>, -} - -impl> MatrixReducer { - pub(crate) fn new(alpha: EF) -> Self { - let alpha_pows = alpha - .powers() - .take(F::Packing::WIDTH * BATCH_SIZE) - .collect_vec(); - let transposed_alphas = (0..EF::D) - .map(|i| { - indices_arr::().map(|j| { - F::Packing::from_fn(|k| { - alpha_pows[j * F::Packing::WIDTH + k].as_base_slice()[i] - }) - }) - }) - .collect_vec(); - Self { - alpha, - alpha_pow_width: alpha.exp_u64(F::Packing::WIDTH as u64), - transposed_alphas, - } - } - - #[instrument(name = "fold in matrices", level = "debug", skip(self, reduced, mats))] - pub(crate) fn reduce_matrices + Sync>( - &self, - reduced: &mut [EF], - height: usize, - mats: &[M], - ) { - // precompute alpha_pows, since we are not doing horner - let mut current_alpha_pow = EF::one(); - let mut alpha_pows = vec![]; - for mat in mats { - let num_packed = mat.width() / F::Packing::WIDTH; - let num_leftover = mat.width() % F::Packing::WIDTH; - for chunk in &(0..num_packed).chunks(BATCH_SIZE) { - alpha_pows.push(current_alpha_pow); - current_alpha_pow *= self.alpha_pow_width.exp_u64(chunk.count() as u64); - } - for _ in 0..num_leftover { - alpha_pows.push(current_alpha_pow); - current_alpha_pow *= self.alpha; - } - } - - /* - We have packed base elements ys and extension field α^n's - We want α^0 * ys[0] + α^1 * ys[1] + ... - - Level 0: - alpha.powers().zip(cols).map(|alpha_pow, col| alpha_pow * col).sum() - - Level 1: (assume Packing::WIDTH=4, D=2, for clarity, although D will usually be higher) - - transposed_alphas - α^0 α^1 α^2 α^3 - [1] [.] [.] [.] - [1] [.] [.] [.] - ys:[a b c d] [e f g h] ... - - We multiply ys vertically, then sum horizontally. - Aka, if β is α^0*a + α^1*b + α^2*c + α^3*d, then each limb of β is - determined by multiplying the packed ys by the appropriate row of - transposed_alphas, then summing horizontally. - - This assumes we are in an extension field where multiplication and addition - by a base element is simply elementwise. - - Then, to fold it into our running reduction, we perform one extension multiplication. - In this scheme, the extension mul takes about 40% of the time in this function. - - Level 2: Batching. To delay the horizontal sum and extension mul as much as possible, - we precompute even more columns of transposed_alphas, and sum the packed results per - batch, and only do one horizontal sum and extension mul per batch. 
- - */ - - reduced - .into_par_iter() - .enumerate() - .for_each(|(r, reduced_row)| { - let mut alpha_pow_iter = alpha_pows.iter(); - for mat in mats { - let row_vec = mat.row_vec(r); - let (packed_row, sfx) = F::Packing::pack_slice_with_suffix(&row_vec); - for packed_col_chunk in packed_row.chunks(BATCH_SIZE) { - let chunk_sum = EF::from_base_fn(|i| { - let chunk_limb_sum = packed_col_chunk - .iter() - .zip(self.transposed_alphas[i]) - .map(|(&packed_col, packed_alpha)| packed_col * packed_alpha) - .sum::(); - chunk_limb_sum.as_slice().iter().copied().sum::() - }); - *reduced_row += *alpha_pow_iter.next().unwrap() * chunk_sum; - } - for &col in sfx { - *reduced_row += *alpha_pow_iter.next().unwrap() * col; - } - } - }) - } -} - -#[cfg(test)] -mod tests { - use alloc::vec; - - use p3_baby_bear::BabyBear; - use p3_field::extension::BinomialExtensionField; - use p3_field::AbstractField; - use p3_matrix::dense::RowMajorMatrix; - use p3_matrix::MatrixRows; - use rand::Rng; - - use super::MatrixReducer; - - type F = BabyBear; - // 5 instead of 4 to make sure it works when EF::D != Packing::WIDTH - type EF = BinomialExtensionField; - - #[test] - fn test_matrix_reducer() { - let mut rng = rand::thread_rng(); - let alpha: EF = rng.gen(); - let height = 32; - let mats0: &[RowMajorMatrix] = &[ - RowMajorMatrix::rand(&mut rng, height, 37), - RowMajorMatrix::rand(&mut rng, height, 13), - ]; - let mats1: &[RowMajorMatrix] = &[ - RowMajorMatrix::rand(&mut rng, height, 41), - RowMajorMatrix::rand(&mut rng, height, 10), - ]; - let reducer = MatrixReducer::new(alpha); - let mut reduced = vec![EF::zero(); height]; - reducer.reduce_matrices(&mut reduced, height, mats0); - reducer.reduce_matrices(&mut reduced, height, mats1); - - let mut correct = vec![EF::zero(); height]; - for (r, correct_reduced) in correct.iter_mut().enumerate() { - for batch in [mats0, mats1] { - let mut current = EF::one(); - for mat in batch { - for col in mat.row(r) { - *correct_reduced += current * col; - current *= alpha; - } - } - } - } - - assert_eq!(reduced, correct); - } -} diff --git a/fri/src/proof.rs b/fri/src/proof.rs index edb1df6c..c75cfcaf 100644 --- a/fri/src/proof.rs +++ b/fri/src/proof.rs @@ -1,49 +1,39 @@ use alloc::vec::Vec; use p3_commit::Mmcs; +use p3_field::Field; use serde::{Deserialize, Serialize}; -use crate::FriConfig; - #[derive(Serialize, Deserialize)] -#[serde(bound = "")] -pub struct FriProof { - pub(crate) commit_phase_commits: Vec<>::Commitment>, - pub(crate) query_proofs: Vec>, +#[serde(bound( + serialize = "Witness: Serialize", + deserialize = "Witness: Deserialize<'de>" +))] +pub struct FriProof, Witness> { + pub(crate) commit_phase_commits: Vec, + pub(crate) query_proofs: Vec>, // This could become Vec if this library was generalized to support non-constant // final polynomials. - pub(crate) final_poly: FC::Challenge, - pub(crate) pow_witness: FC::Val, + pub(crate) final_poly: F, + pub(crate) pow_witness: Witness, } #[derive(Serialize, Deserialize)] #[serde(bound = "")] -pub struct QueryProof { - /// For each input commitment, this contains openings of each matrix at the queried location, - /// along with an opening proof. - pub(crate) input_openings: Vec>, - +pub struct QueryProof> { /// For each commit phase commitment, this contains openings of a commit phase codeword at the /// queried location, along with an opening proof. - pub(crate) commit_phase_openings: Vec>, -} - -/// Openings of each input codeword at the queried location, along with an opening proof, for a -/// single commitment round. 
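Note on the `fri/src/proof.rs` changes above: with input openings moved out of the FRI proof (they now live beside it, in the PCS proof's `query_openings`), a `QueryProof` is just one `CommitPhaseProofStep` per fold round, holding the sibling value and its Merkle proof. The index bookkeeping in `answer_query`/`verify_query` assumes arity-2 folding; a sketch of that arithmetic with a hypothetical `fold_indices` helper:

```rust
// For fold round i, the query index is examined at a shrinking resolution.
// Each committed round-i matrix stores sibling pairs as rows of width 2.
fn fold_indices(index: usize, round: usize) -> (usize, usize, usize) {
    let index_i = index >> round;      // position in the round-i codeword
    let index_i_sibling = index_i ^ 1; // the paired (sibling) position
    let index_pair = index_i >> 1;     // the row opened in the round-i commitment
    (index_i, index_i_sibling, index_pair)
}
```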
-#[derive(Serialize, Deserialize)] -pub struct InputOpening { - /// The opening of each input codeword at the queried location. - pub(crate) opened_values: Vec>, - - pub(crate) opening_proof: >::Proof, + pub(crate) commit_phase_openings: Vec>, } #[derive(Serialize, Deserialize)] -pub struct CommitPhaseProofStep { +// #[serde(bound(serialize = "F: Serialize"))] +#[serde(bound = "")] +pub struct CommitPhaseProofStep> { /// The opening of the commit phase codeword at the sibling location. // This may change to Vec if the library is generalized to support other FRI // folding arities besides 2, meaning that there can be multiple siblings. - pub(crate) sibling_value: FC::Challenge, + pub(crate) sibling_value: F, - pub(crate) opening_proof: >::Proof, + pub(crate) opening_proof: M::Proof, } diff --git a/fri/src/prover.rs b/fri/src/prover.rs index 7f0b914e..706afbc5 100644 --- a/fri/src/prover.rs +++ b/fri/src/prover.rs @@ -1,84 +1,64 @@ use alloc::vec; use alloc::vec::Vec; -use p3_challenger::{CanObserve, CanSampleBits, FieldChallenger, GrindingChallenger}; +use itertools::Itertools; +use p3_challenger::{CanObserve, CanSample, GrindingChallenger}; use p3_commit::{DirectMmcs, Mmcs}; -use p3_field::AbstractField; +use p3_field::{Field, TwoAdicField}; use p3_matrix::dense::RowMajorMatrix; -use p3_matrix::Matrix; -use p3_util::log2_strict_usize; use tracing::{info_span, instrument}; use crate::fold_even_odd::fold_even_odd; -use crate::matrix_reducer::MatrixReducer; -use crate::{CommitPhaseProofStep, FriConfig, FriProof, InputOpening, QueryProof}; +use crate::{CommitPhaseProofStep, FriConfig, FriProof, QueryProof}; #[instrument(name = "FRI prover", skip_all)] -pub(crate) fn prove( - config: &FC, - input_mmcs: &[FC::InputMmcs], - input_data: &[&>::ProverData], - challenger: &mut FC::Challenger, -) -> FriProof { - let max_height = input_mmcs - .iter() - .zip(input_data) - .map(|(mmcs, commit)| mmcs.get_max_height(commit)) - .max() - .unwrap_or_else(|| panic!("No matrices?")); - let log_max_height = log2_strict_usize(max_height); - - let commit_phase_result = - commit_phase::(config, input_mmcs, input_data, log_max_height, challenger); - - let pow_witness = challenger.grind(config.proof_of_work_bits()); - - let query_indices: Vec = (0..config.num_queries()) +pub fn prove( + config: &FriConfig, + input: &[Option>; 32], + challenger: &mut Challenger, +) -> (FriProof, Vec) +where + F: TwoAdicField, + M: DirectMmcs, + Challenger: GrindingChallenger + CanObserve + CanSample, +{ + let log_max_height = input.iter().rposition(Option::is_some).unwrap(); + + let commit_phase_result = commit_phase(config, input, log_max_height, challenger); + + let pow_witness = challenger.grind(config.proof_of_work_bits); + + let query_indices: Vec = (0..config.num_queries) .map(|_| challenger.sample_bits(log_max_height)) .collect(); let query_proofs = info_span!("query phase").in_scope(|| { query_indices - .into_iter() - .map(|index| { - answer_query( - config, - input_mmcs, - input_data, - &commit_phase_result.data, - index, - ) - }) + .iter() + .map(|&index| answer_query(config, &commit_phase_result.data, index)) .collect() }); - FriProof { - commit_phase_commits: commit_phase_result.commits, - query_proofs, - final_poly: commit_phase_result.final_poly, - pow_witness, - } + ( + FriProof { + commit_phase_commits: commit_phase_result.commits, + query_proofs, + final_poly: commit_phase_result.final_poly, + pow_witness, + }, + query_indices, + ) } -fn answer_query( - config: &FC, - input_mmcs: &[FC::InputMmcs], - input_data: 
&[&>::ProverData], - commit_phase_commits: &[>::ProverData], +fn answer_query( + config: &FriConfig, + commit_phase_commits: &[M::ProverData], index: usize, -) -> QueryProof { - let input_openings = input_mmcs - .iter() - .zip(input_data) - .map(|(mmcs, commit)| { - let (opened_values, opening_proof) = mmcs.open_batch(index, commit); - InputOpening { - opened_values, - opening_proof, - } - }) - .collect(); - +) -> QueryProof +where + F: Field, + M: Mmcs, +{ let commit_phase_openings = commit_phase_commits .iter() .enumerate() @@ -87,8 +67,7 @@ fn answer_query( let index_i_sibling = index_i ^ 1; let index_pair = index_i >> 1; - let (mut opened_rows, opening_proof) = - config.commit_phase_mmcs().open_batch(index_pair, commit); + let (mut opened_rows, opening_proof) = config.mmcs.open_batch(index_pair, commit); assert_eq!(opened_rows.len(), 1); let opened_row = opened_rows.pop().unwrap(); assert_eq!(opened_row.len(), 2, "Committed data should be in pairs"); @@ -102,53 +81,39 @@ fn answer_query( .collect(); QueryProof { - input_openings, commit_phase_openings, } } #[instrument(name = "commit phase", skip_all)] -fn commit_phase( - config: &FC, - input_mmcs: &[FC::InputMmcs], - input_data: &[&>::ProverData], +fn commit_phase( + config: &FriConfig, + input: &[Option>; 32], log_max_height: usize, - challenger: &mut FC::Challenger, -) -> CommitPhaseResult { - let max_height = 1 << log_max_height; - - let mut matrices_by_log_height: Vec> = vec![]; - matrices_by_log_height.resize_with(log_max_height + 1, Default::default); - for (mmcs, commit) in input_mmcs.iter().zip(input_data) { - for mat in mmcs.get_matrices(commit) { - matrices_by_log_height[log2_strict_usize(mat.height())].push(mat); - } - } - - let largest_matrices = &matrices_by_log_height[log_max_height]; - let alpha: FC::Challenge = challenger.sample_ext_element(); - let alpha_reducer = MatrixReducer::new(alpha); - let mut current = vec![FC::Challenge::zero(); max_height]; - alpha_reducer.reduce_matrices(&mut current, max_height, largest_matrices); + challenger: &mut Challenger, +) -> CommitPhaseResult +where + F: TwoAdicField, + M: DirectMmcs, + Challenger: CanObserve + CanSample, +{ + let mut current = input[log_max_height].as_ref().unwrap().clone(); let mut commits = vec![]; let mut data = vec![]; - for log_folded_height in (config.log_blowup()..log_max_height).rev() { - let folded_height = 1 << log_folded_height; - // TODO: avoid cloning + for log_folded_height in (config.log_blowup..log_max_height).rev() { let leaves = RowMajorMatrix::new(current.clone(), 2); - let (commit, prover_data) = config.commit_phase_mmcs().commit_matrix(leaves); + let (commit, prover_data) = config.mmcs.commit_matrix(leaves); challenger.observe(commit.clone()); commits.push(commit); data.push(prover_data); - let beta: FC::Challenge = challenger.sample_ext_element(); + let beta: F = challenger.sample(); current = fold_even_odd(current, beta); - let matrices = &matrices_by_log_height[log_folded_height]; - if !matrices.is_empty() { - alpha_reducer.reduce_matrices(&mut current, folded_height, matrices); + if let Some(v) = &input[log_folded_height] { + current.iter_mut().zip_eq(v).for_each(|(c, v)| *c += *v); } } @@ -166,8 +131,8 @@ fn commit_phase( } } -struct CommitPhaseResult { - commits: Vec<>::Commitment>, - data: Vec<>::ProverData>, - final_poly: FC::Challenge, +struct CommitPhaseResult> { + commits: Vec, + data: Vec, + final_poly: F, } diff --git a/fri/src/two_adic_pcs.rs b/fri/src/two_adic_pcs.rs new file mode 100644 index 00000000..fbbb606e --- /dev/null 
+++ b/fri/src/two_adic_pcs.rs @@ -0,0 +1,579 @@ +use alloc::vec; +use alloc::vec::Vec; +use core::fmt::{Debug, Formatter}; +use core::marker::PhantomData; + +use itertools::{izip, Itertools}; +use p3_challenger::{CanObserve, CanSample, FieldChallenger, GrindingChallenger}; +use p3_commit::{DirectMmcs, Mmcs, OpenedValues, Pcs, UnivariatePcs, UnivariatePcsWithLde}; +use p3_dft::TwoAdicSubgroupDft; +use p3_field::{ + batch_multiplicative_inverse, cyclic_subgroup_coset_known_order, AbstractField, ExtensionField, + Field, PackedField, TwoAdicField, +}; +use p3_interpolation::interpolate_coset; +use p3_matrix::bitrev::{BitReversableMatrix, BitReversedMatrixView}; +use p3_matrix::dense::RowMajorMatrixView; +use p3_matrix::{Dimensions, Matrix, MatrixRows}; +use p3_maybe_rayon::prelude::*; +use p3_util::linear_map::LinearMap; +use p3_util::{log2_strict_usize, reverse_bits_len, reverse_slice_index_bits, VecExt}; +use serde::{Deserialize, Serialize}; +use tracing::{info_span, instrument}; + +use crate::verifier::{self, FriError}; +use crate::{prover, FriConfig, FriProof}; + +/// We group all of our type bounds into this trait to reduce duplication across signatures. +pub trait TwoAdicFriPcsGenericConfig: Default { + type Val: TwoAdicField; + type Challenge: TwoAdicField + ExtensionField; + type Challenger: FieldChallenger + + GrindingChallenger + + CanObserve<>::Commitment> + + CanSample; + type Dft: TwoAdicSubgroupDft; + type InputMmcs: 'static + + for<'a> DirectMmcs = RowMajorMatrixView<'a, Self::Val>>; + type FriMmcs: DirectMmcs; +} + +pub struct TwoAdicFriPcsConfig( + PhantomData<(Val, Challenge, Challenger, Dft, InputMmcs, FriMmcs)>, +); +impl Default + for TwoAdicFriPcsConfig +{ + fn default() -> Self { + Self(PhantomData) + } +} + +impl TwoAdicFriPcsGenericConfig + for TwoAdicFriPcsConfig +where + Val: TwoAdicField, + Challenge: TwoAdicField + ExtensionField, + Challenger: FieldChallenger + + GrindingChallenger + + CanObserve<>::Commitment> + + CanSample, + Dft: TwoAdicSubgroupDft, + InputMmcs: 'static + for<'a> DirectMmcs = RowMajorMatrixView<'a, Val>>, + FriMmcs: DirectMmcs, +{ + type Val = Val; + type Challenge = Challenge; + type Challenger = Challenger; + type Dft = Dft; + type InputMmcs = InputMmcs; + type FriMmcs = FriMmcs; +} + +pub struct TwoAdicFriPcs { + fri: FriConfig, + dft: C::Dft, + mmcs: C::InputMmcs, +} + +impl TwoAdicFriPcs { + pub fn new(fri: FriConfig, dft: C::Dft, mmcs: C::InputMmcs) -> Self { + Self { fri, dft, mmcs } + } +} + +pub enum VerificationError { + InputMmcsError(>::Error), + FriError(FriError<>::Error>), +} + +impl Debug for VerificationError { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + match self { + VerificationError::InputMmcsError(e) => { + f.debug_tuple("InputMmcsError").field(e).finish() + } + VerificationError::FriError(e) => f.debug_tuple("FriError").field(e).finish(), + } + } +} + +#[derive(Serialize, Deserialize)] +#[serde(bound = "")] +pub struct TwoAdicFriPcsProof { + pub(crate) fri_proof: FriProof, + /// For each query, for each committed batch, query openings for that batch + pub(crate) query_openings: Vec>>, +} + +#[derive(Serialize, Deserialize)] +pub struct BatchOpening { + pub(crate) opened_values: Vec>, + pub(crate) opening_proof: >::Proof, +} + +impl> Pcs for TwoAdicFriPcs { + type Commitment = >::Commitment; + type ProverData = >::ProverData; + type Proof = TwoAdicFriPcsProof; + type Error = VerificationError; + + fn commit_batches(&self, polynomials: Vec) -> (Self::Commitment, Self::ProverData) { + let ones = 
vec![C::Val::one(); polynomials.len()]; + self.commit_shifted_batches(polynomials, &ones) + } +} + +impl> + UnivariatePcsWithLde for TwoAdicFriPcs +{ + type Lde<'a> = BitReversedMatrixView<>::Mat<'a>> where Self: 'a; + + fn coset_shift(&self) -> C::Val { + C::Val::generator() + } + + fn log_blowup(&self) -> usize { + self.fri.log_blowup + } + + fn get_ldes<'a, 'b>(&'a self, prover_data: &'b Self::ProverData) -> Vec> + where + 'a: 'b, + { + // We committed to the bit-reversed LDE, so now we wrap it to return in natural order. + self.mmcs + .get_matrices(prover_data) + .into_iter() + .map(BitReversedMatrixView::new) + .collect() + } + + fn commit_shifted_batches( + &self, + polynomials: Vec, + coset_shifts: &[C::Val], + ) -> (Self::Commitment, Self::ProverData) { + let ldes = info_span!("compute all coset LDEs").in_scope(|| { + polynomials + .into_iter() + .zip_eq(coset_shifts) + .map(|(poly, coset_shift)| { + let shift = C::Val::generator() / *coset_shift; + let input = poly.to_row_major_matrix(); + // Commit to the bit-reversed LDE. + self.dft + .coset_lde_batch(input, self.fri.log_blowup, shift) + .bit_reverse_rows() + .to_row_major_matrix() + }) + .collect() + }); + self.mmcs.commit(ldes) + } +} + +impl> + UnivariatePcs for TwoAdicFriPcs +{ + #[instrument(name = "open_multi_batches", skip_all)] + fn open_multi_batches( + &self, + prover_data_and_points: &[(&Self::ProverData, &[Vec])], + challenger: &mut C::Challenger, + ) -> (OpenedValues, Self::Proof) { + // Batch combination challenge + let alpha = >::sample(challenger); + + /* + + A quick rundown of the optimizations in this function: + We are trying to compute sum_i alpha^i * (p(X) - y)/(X - z), + for each z an opening point, y = p(z). Each p(X) is given as evaluations in bit-reversed order + in the columns of the matrices. y is computed by barycentric interpolation. + X and p(X) are in the base field; alpha, y and z are in the extension. + The primary goal is to minimize extension multiplications. + + - Instead of computing all alpha^i, we just compute alpha^i for i up to the largest width + of a matrix, then multiply by an "alpha offset" when accumulating. + a^0 x0 + a^1 x1 + a^2 x2 + a^3 x3 + ... + = a^0 ( a^0 x0 + a^1 x1 ) + a^2 ( a^0 x0 + a^1 x1 ) + ... + (see `alpha_pows`, `alpha_pow_offset`, `num_reduced`) + + - For each unique point z, we precompute 1/(X-z) for the largest subgroup opened at this point. + Since we compute it in bit-reversed order, smaller subgroups can simply truncate the vector. + (see `inv_denoms`) + + - Then, for each matrix (with columns p_i) and opening point z, we want: + for each row (corresponding to subgroup element X): + reduced[X] += alpha_offset * sum_i [ alpha^i * inv_denom[X] * (p_i[X] - y[i]) ] + + We can factor out inv_denom, and expand what's left: + reduced[X] += alpha_offset * inv_denom[X] * sum_i [ alpha^i * p_i[X] - alpha^i * y[i] ] + + And separate the sum: + reduced[X] += alpha_offset * inv_denom[X] * sum_i [ alpha^i * p_i[X] ] - sum_i [ alpha^i * y[i] ] + + And now the last sum doesn't depend on X, so we can precompute that for the matrix, too. 
+ So the hot loop (that depends on both X and i) is just: + sum_i [ alpha^i * p_i[X] ] + + with alpha^i an extension, p_i[X] a base + + */ + + let mats_and_points = prover_data_and_points + .iter() + .map(|(data, points)| (self.mmcs.get_matrices(data), *points)) + .collect_vec(); + + let max_width = mats_and_points + .iter() + .flat_map(|(mats, _)| mats) + .map(|mat| mat.width()) + .max() + .unwrap(); + + let alpha_reducer = PowersReducer::::new(alpha, max_width); + + // For each unique opening point z, we will find the largest degree bound + // for that point, and precompute 1/(X - z) for the largest subgroup (in bitrev order). + let inv_denoms = compute_inverse_denominators(&mats_and_points, C::Val::generator()); + + let mut all_opened_values: OpenedValues = vec![]; + let mut reduced_openings: [_; 32] = core::array::from_fn(|_| None); + let mut num_reduced = [0; 32]; + + for (data, points) in prover_data_and_points { + let mats = self.mmcs.get_matrices(data); + let opened_values_for_round = all_opened_values.pushed_mut(vec![]); + for (mat, points_for_mat) in izip!(mats, *points) { + let log_height = log2_strict_usize(mat.height()); + let reduced_opening_for_log_height = reduced_openings[log_height] + .get_or_insert_with(|| vec![C::Challenge::zero(); mat.height()]); + debug_assert_eq!(reduced_opening_for_log_height.len(), mat.height()); + + let opened_values_for_mat = opened_values_for_round.pushed_mut(vec![]); + for &point in points_for_mat { + let _guard = + info_span!("reduce matrix quotient", dims = %mat.dimensions()).entered(); + + // Use Barycentric interpolation to evaluate the matrix at the given point. + let ys = info_span!("compute opened values with Lagrange interpolation") + .in_scope(|| { + let (low_coset, _) = + mat.split_rows(mat.height() >> self.fri.log_blowup); + interpolate_coset( + &BitReversedMatrixView::new(low_coset), + C::Val::generator(), + point, + ) + }); + + let alpha_pow_offset = alpha.exp_u64(num_reduced[log_height] as u64); + let sum_alpha_pows_times_y = alpha_reducer.reduce_ext(&ys); + + info_span!("reduce rows").in_scope(|| { + reduced_opening_for_log_height + .par_iter_mut() + .zip_eq(mat.par_rows()) + // This might be longer, but zip will truncate to smaller subgroup + // (which is ok because it's bitrev) + .zip(inv_denoms.get(&point).unwrap()) + .for_each(|((reduced_opening, row), &inv_denom)| { + let row_sum = alpha_reducer.reduce_base(row); + *reduced_opening += inv_denom + * alpha_pow_offset + * (row_sum - sum_alpha_pows_times_y); + }); + }); + + num_reduced[log_height] += mat.width(); + opened_values_for_mat.push(ys); + } + } + } + + let (fri_proof, query_indices) = prover::prove(&self.fri, &reduced_openings, challenger); + + let query_openings = query_indices + .into_iter() + .map(|index| { + prover_data_and_points + .iter() + .map(|(data, _)| { + let (opened_values, opening_proof) = self.mmcs.open_batch(index, data); + BatchOpening { + opened_values, + opening_proof, + } + }) + .collect() + }) + .collect(); + + ( + all_opened_values, + TwoAdicFriPcsProof { + fri_proof, + query_openings, + }, + ) + } + + fn verify_multi_batches( + &self, + commits_and_points: &[(Self::Commitment, &[Vec])], + dims: &[Vec], + values: OpenedValues, + proof: &Self::Proof, + challenger: &mut C::Challenger, + ) -> Result<(), Self::Error> { + // Batch combination challenge + let alpha = >::sample(challenger); + + let fri_challenges = + verifier::verify_shape_and_sample_challenges(&self.fri, &proof.fri_proof, challenger) + .map_err(VerificationError::FriError)?; + + let 
log_max_height = proof.fri_proof.commit_phase_commits.len() + self.fri.log_blowup; + + let reduced_openings: Vec<[C::Challenge; 32]> = proof + .query_openings + .iter() + .zip(&fri_challenges.query_indices) + .map(|(query_opening, &index)| { + let mut ro = [C::Challenge::zero(); 32]; + let mut alpha_pow = [C::Challenge::one(); 32]; + for (batch_opening, batch_dims, (batch_commit, batch_points), batch_at_z) in + izip!(query_opening, dims, commits_and_points, &values) + { + self.mmcs.verify_batch( + batch_commit, + batch_dims, + index, + &batch_opening.opened_values, + &batch_opening.opening_proof, + )?; + for (mat_opening, mat_dims, mat_points, mat_at_z) in izip!( + &batch_opening.opened_values, + batch_dims, + *batch_points, + batch_at_z + ) { + let log_height = log2_strict_usize(mat_dims.height) + self.fri.log_blowup; + + let bits_reduced = log_max_height - log_height; + let rev_reduced_index = reverse_bits_len(index >> bits_reduced, log_height); + + let x = C::Val::generator() + * C::Val::two_adic_generator(log_height) + .exp_u64(rev_reduced_index as u64); + + for (&z, ps_at_z) in izip!(mat_points, mat_at_z) { + for (&p_at_x, &p_at_z) in izip!(mat_opening, ps_at_z) { + let quotient = (-p_at_z + p_at_x) / (-z + x); + ro[log_height] += alpha_pow[log_height] * quotient; + alpha_pow[log_height] *= alpha; + } + } + } + } + Ok(ro) + }) + .collect::, >::Error>>() + .map_err(VerificationError::InputMmcsError)?; + + verifier::verify_challenges( + &self.fri, + &proof.fri_proof, + &fri_challenges, + &reduced_openings, + ) + .map_err(VerificationError::FriError)?; + + Ok(()) + } +} + +#[instrument(skip_all)] +fn compute_inverse_denominators, M: Matrix>( + mats_and_points: &[(Vec, &[Vec])], + coset_shift: F, +) -> LinearMap> { + let mut max_log_height_for_point: LinearMap = LinearMap::new(); + for (mats, points) in mats_and_points { + for (mat, points_for_mat) in izip!(mats, *points) { + let log_height = log2_strict_usize(mat.height()); + for &z in points_for_mat { + if let Some(lh) = max_log_height_for_point.get_mut(&z) { + *lh = core::cmp::max(*lh, log_height); + } else { + max_log_height_for_point.insert(z, log_height); + } + } + } + } + + // Compute the largest subgroup we will use, in bitrev order. 
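Note on the verification loop above: for each opened matrix, the verifier recomputes the reduced opening itself, accumulating `alpha^k * (p(x) - p(z)) / (x - z)` per column into the bucket for that matrix's `log_height`, where `x` is the coset point recovered from the bit-reversed query index. A sketch of the innermost update, with a hypothetical `accumulate` helper standing in for the loop body:

```rust
use p3_field::{ExtensionField, Field};

// p_at_x: committed base-field value opened at x; p_at_z: claimed value at z.
fn accumulate<F: Field, EF: ExtensionField<F>>(
    ro: &mut EF,        // reduced-opening bucket for this log_height
    alpha_pow: &mut EF, // running power of the batching challenge
    alpha: EF,
    p_at_x: F,
    p_at_z: EF,
    x: F,
    z: EF,
) {
    let quotient = (-p_at_z + p_at_x) / (-z + x); // (p(x) - p(z)) / (x - z)
    *ro += *alpha_pow * quotient;
    *alpha_pow *= alpha;
}
```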
+ let max_log_height = *max_log_height_for_point.values().max().unwrap(); + let mut subgroup = cyclic_subgroup_coset_known_order( + F::two_adic_generator(max_log_height), + coset_shift, + 1 << max_log_height, + ) + .collect_vec(); + reverse_slice_index_bits(&mut subgroup); + + max_log_height_for_point + .into_iter() + .map(|(z, log_height)| { + ( + z, + batch_multiplicative_inverse( + &subgroup[..(1 << log_height)] + .iter() + .map(|&x| EF::from_base(x) - z) + .collect_vec(), + ), + ) + }) + .collect() +} + +struct PowersReducer { + powers: Vec, + // If EF::D = 2 and powers is [01 23 45 67], + // this holds [[02 46] [13 57]] + transposed_packed: Vec>, +} + +impl> PowersReducer { + fn new(base: EF, max_width: usize) -> Self { + let powers: Vec = base + .powers() + .take(max_width.next_multiple_of(F::Packing::WIDTH)) + .collect(); + + let transposed_packed: Vec> = transpose_vec( + (0..EF::D) + .map(|d| { + F::Packing::pack_slice( + &powers.iter().map(|a| a.as_base_slice()[d]).collect_vec(), + ) + .to_vec() + }) + .collect(), + ); + + Self { + powers, + transposed_packed, + } + } + + // Compute sum_i base^i * x_i + fn reduce_ext(&self, xs: &[EF]) -> EF { + self.powers.iter().zip(xs).map(|(&pow, &x)| pow * x).sum() + } + + // Same as `self.powers.iter().zip(xs).map(|(&pow, &x)| pow * x).sum()` + fn reduce_base(&self, xs: &[F]) -> EF { + let (xs_packed, xs_sfx) = F::Packing::pack_slice_with_suffix(xs); + let mut sums = (0..EF::D).map(|_| F::Packing::zero()).collect::>(); + for (&x, pows) in izip!(xs_packed, &self.transposed_packed) { + for d in 0..EF::D { + sums[d] += x * pows[d]; + } + } + let packed_sum = EF::from_base_fn(|d| sums[d].as_slice().iter().copied().sum()); + let sfx_sum = xs_sfx + .iter() + .zip(&self.powers[(xs_packed.len() * F::Packing::WIDTH)..]) + .map(|(&x, &pow)| pow * x) + .sum::(); + packed_sum + sfx_sum + } +} + +fn transpose_vec(v: Vec>) -> Vec> { + assert!(!v.is_empty()); + let len = v[0].len(); + let mut iters: Vec<_> = v.into_iter().map(|n| n.into_iter()).collect(); + (0..len) + .map(|_| { + iters + .iter_mut() + .map(|n| n.next().unwrap()) + .collect::>() + }) + .collect() +} + +#[cfg(test)] +mod tests { + + use p3_baby_bear::BabyBear; + use p3_field::extension::BinomialExtensionField; + use p3_field::AbstractExtensionField; + use rand::{thread_rng, Rng}; + + use super::*; + + type F = BabyBear; + type EF = BinomialExtensionField; + + #[test] + fn test_powers_reducer() { + let mut rng = thread_rng(); + let alpha: EF = rng.gen(); + let n = 1000; + let sizes = [5, 110, 512, 999, 1000]; + let r = PowersReducer::::new(alpha, n); + + // check reduce_ext + for size in sizes { + let xs: Vec = (0..size).map(|_| rng.gen()).collect(); + assert_eq!( + r.reduce_ext(&xs), + xs.iter() + .enumerate() + .map(|(i, &x)| alpha.exp_u64(i as u64) * x) + .sum() + ); + } + + // check reduce_base + for size in sizes { + let xs: Vec = (0..size).map(|_| rng.gen()).collect(); + assert_eq!( + r.reduce_base(&xs), + xs.iter() + .enumerate() + .map(|(i, &x)| alpha.exp_u64(i as u64) * EF::from_base(x)) + .sum() + ); + } + + // bench reduce_base + /* + use core::hint::black_box; + use std::time::Instant; + let samples = 1_000; + for i in 0..5 { + let xs: Vec = (0..999).map(|_| rng.gen()).collect(); + let t0 = Instant::now(); + for _ in 0..samples { + black_box(r.reduce_base_slow(black_box(&xs))); + } + let dt_slow = t0.elapsed(); + let t0 = Instant::now(); + for _ in 0..samples { + black_box(r.reduce_base(black_box(&xs))); + } + let dt_fast = t0.elapsed(); + println!("sample {i}: slow: {dt_slow:?} 
fast: {dt_fast:?}"); + } + */ + } +} diff --git a/fri/src/verifier.rs b/fri/src/verifier.rs index 93109ae6..d7158d2e 100644 --- a/fri/src/verifier.rs +++ b/fri/src/verifier.rs @@ -2,149 +2,126 @@ use alloc::vec; use alloc::vec::Vec; use itertools::izip; -use p3_challenger::{CanObserve, CanSampleBits, FieldChallenger, GrindingChallenger}; +use p3_challenger::{CanObserve, CanSample, GrindingChallenger}; use p3_commit::Mmcs; -use p3_field::{AbstractField, Field, TwoAdicField}; +use p3_field::{Field, TwoAdicField}; use p3_matrix::Dimensions; -use p3_util::{log2_strict_usize, reverse_bits_len}; +use p3_util::reverse_bits_len; -use crate::{FriConfig, FriProof, InputOpening, QueryProof}; +use crate::{FriConfig, FriProof, QueryProof}; #[derive(Debug)] -pub enum VerificationError { +pub enum FriError { InvalidProofShape, - InputMmcsError(InputMmcsErr), CommitPhaseMmcsError(CommitMmcsErr), FinalPolyMismatch, InvalidPowWitness, } -pub type VerificationErrorForFriConfig = VerificationError< - <::InputMmcs as Mmcs<::Val>>::Error, - <::CommitPhaseMmcs as Mmcs<::Challenge>>::Error, ->; - -type VerificationResult = Result>; - -pub(crate) fn verify( - config: &FC, - input_mmcs: &[FC::InputMmcs], - input_dims: &[Vec], - input_commits: &[>::Commitment], - proof: &FriProof, - challenger: &mut FC::Challenger, -) -> VerificationResult { - let alpha: FC::Challenge = challenger.sample_ext_element(); - let betas: Vec = proof +#[derive(Debug)] +pub struct FriChallenges { + pub query_indices: Vec, + betas: Vec, +} + +pub fn verify_shape_and_sample_challenges( + config: &FriConfig, + proof: &FriProof, + challenger: &mut Challenger, +) -> Result, FriError> +where + F: Field, + M: Mmcs, + Challenger: GrindingChallenger + CanObserve + CanSample, +{ + let betas: Vec = proof .commit_phase_commits .iter() .map(|comm| { challenger.observe(comm.clone()); - challenger.sample_ext_element() + challenger.sample() }) .collect(); - if proof.query_proofs.len() != config.num_queries() { - return Err(VerificationError::InvalidProofShape); + if proof.query_proofs.len() != config.num_queries { + return Err(FriError::InvalidProofShape); } // Check PoW. 
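Note on `PowersReducer` above: `reduce_base` is a packed rewrite of a plain dot product against powers of the batching challenge, and the `test_powers_reducer` test checks it against exactly that. A reference version under the same trait bounds (hypothetical name `reduce_base_naive`):

```rust
use p3_field::{AbstractField, ExtensionField, Field};

// Unoptimized reference for PowersReducer::{reduce_base, reduce_ext}:
// computes sum_i alpha^i * xs[i]. The packed version produces the same sum
// limb by limb, using F::Packing and a transposed table of alpha powers.
fn reduce_base_naive<F: Field, EF: ExtensionField<F>>(alpha: EF, xs: &[F]) -> EF {
    alpha.powers().zip(xs).map(|(pow, &x)| pow * x).sum()
}
```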
- if !challenger.check_witness(config.proof_of_work_bits(), proof.pow_witness) { - return Err(VerificationError::InvalidPowWitness); + if !challenger.check_witness(config.proof_of_work_bits, proof.pow_witness) { + return Err(FriError::InvalidPowWitness); } - let log_max_height = proof.commit_phase_commits.len() + config.log_blowup(); + let log_max_height = proof.commit_phase_commits.len() + config.log_blowup; - for query_proof in &proof.query_proofs { - let index = challenger.sample_bits(log_max_height); + let query_indices: Vec = (0..config.num_queries) + .map(|_| challenger.sample_bits(log_max_height)) + .collect(); - let reduced_openings = verify_input( - input_mmcs, - input_commits, - input_dims, - &query_proof.input_openings, - index, - alpha, - log_max_height, - )?; + Ok(FriChallenges { + query_indices, + betas, + }) +} +pub fn verify_challenges( + config: &FriConfig, + proof: &FriProof, + challenges: &FriChallenges, + reduced_openings: &[[F; 32]], +) -> Result<(), FriError> +where + F: TwoAdicField, + M: Mmcs, +{ + let log_max_height = proof.commit_phase_commits.len() + config.log_blowup; + for (&index, query_proof, ro) in izip!( + &challenges.query_indices, + &proof.query_proofs, + reduced_openings + ) { let folded_eval = verify_query( config, &proof.commit_phase_commits, index, query_proof, - &betas, - reduced_openings, + &challenges.betas, + ro, log_max_height, )?; if folded_eval != proof.final_poly { - return Err(VerificationError::FinalPolyMismatch); + return Err(FriError::FinalPolyMismatch); } } Ok(()) } -fn verify_input( - input_mmcs: &[FC::InputMmcs], - input_commits: &[>::Commitment], - input_dims: &[Vec], - input_openings: &[InputOpening], - index: usize, - alpha: FC::Challenge, - log_max_height: usize, -) -> VerificationResult> { - let mut openings_by_log_height: Vec> = vec![vec![]; log_max_height + 1]; - for (mmcs, commit, dims, opening) in - izip!(input_mmcs, input_commits, input_dims, input_openings) - { - mmcs.verify_batch( - commit, - dims, - index, - &opening.opened_values, - &opening.opening_proof, - ) - .map_err(VerificationError::InputMmcsError)?; - for (mat_dims, mat_opening) in izip!(dims, &opening.opened_values) { - let log_height = log2_strict_usize(mat_dims.height); - openings_by_log_height[log_height].extend_from_slice(mat_opening); - } - } - let reduced_openings = openings_by_log_height - .into_iter() - .map(|o| { - o.into_iter() - .zip(alpha.powers()) - .map(|(y, alpha_pow)| alpha_pow * y) - .sum() - }) - .collect(); - Ok(reduced_openings) -} - -fn verify_query( - config: &FC, - commit_phase_commits: &[>::Commitment], +fn verify_query( + config: &FriConfig, + commit_phase_commits: &[M::Commitment], mut index: usize, - proof: &QueryProof, - betas: &[FC::Challenge], - reduced_openings: Vec, + proof: &QueryProof, + betas: &[F], + reduced_openings: &[F; 32], log_max_height: usize, -) -> VerificationResult { - let mut folded_eval = FC::Challenge::zero(); - let mut x = FC::Challenge::two_adic_generator(log_max_height) +) -> Result> +where + F: TwoAdicField, + M: Mmcs, +{ + let mut folded_eval = F::zero(); + let mut x = F::two_adic_generator(log_max_height) .exp_u64(reverse_bits_len(index, log_max_height) as u64); - for (log_folded_height, commit, step, &beta, reduced_opening_for_height) in izip!( + for (log_folded_height, commit, step, &beta) in izip!( (0..log_max_height).rev(), commit_phase_commits, &proof.commit_phase_openings, betas, - reduced_openings.into_iter().rev() ) { - folded_eval += reduced_opening_for_height; + folded_eval += 
reduced_openings[log_folded_height + 1]; let index_sibling = index ^ 1; let index_pair = index >> 1; @@ -157,7 +134,7 @@ fn verify_query( height: (1 << log_folded_height), }]; config - .commit_phase_mmcs() + .mmcs .verify_batch( commit, dims, @@ -165,10 +142,10 @@ fn verify_query( &[evals.clone()], &step.opening_proof, ) - .map_err(VerificationError::CommitPhaseMmcsError)?; + .map_err(FriError::CommitPhaseMmcsError)?; let mut xs = [x; 2]; - xs[index_sibling % 2] *= FC::Challenge::two_adic_generator(1); + xs[index_sibling % 2] *= F::two_adic_generator(1); // interpolate and evaluate at beta folded_eval = evals[0] + (beta - xs[0]) * (evals[1] - evals[0]) / (xs[1] - xs[0]); @@ -177,7 +154,7 @@ fn verify_query( } debug_assert!(index == 0 || index == 1); - debug_assert!(x.is_one() || x == FC::Challenge::two_adic_generator(1)); + debug_assert!(x.is_one() || x == F::two_adic_generator(1)); Ok(folded_eval) } diff --git a/fri/tests/fri.rs b/fri/tests/fri.rs index 522fcb0a..78579534 100644 --- a/fri/tests/fri.rs +++ b/fri/tests/fri.rs @@ -1,18 +1,18 @@ use itertools::Itertools; use p3_baby_bear::BabyBear; -use p3_challenger::{CanSample, DuplexChallenger}; -use p3_commit::{DirectMmcs, ExtensionMmcs}; +use p3_challenger::{CanSampleBits, DuplexChallenger, FieldChallenger}; +use p3_commit::ExtensionMmcs; use p3_dft::{Radix2Dit, TwoAdicSubgroupDft}; use p3_field::extension::BinomialExtensionField; use p3_field::{AbstractField, Field}; -use p3_fri::{FriConfigImpl, FriLdt}; -use p3_ldt::Ldt; +use p3_fri::{prover, verifier, FriConfig}; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::util::reverse_matrix_index_bits; -use p3_matrix::Matrix; +use p3_matrix::{Matrix, MatrixRows}; use p3_merkle_tree::FieldMerkleTreeMmcs; use p3_poseidon2::{DiffusionMatrixBabybear, Poseidon2}; use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; +use p3_util::log2_strict_usize; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha20Rng; @@ -25,43 +25,99 @@ type MyCompress = TruncatedPermutation; type ValMmcs = FieldMerkleTreeMmcs<::Packing, MyHash, MyCompress, 8>; type ChallengeMmcs = ExtensionMmcs; type Challenger = DuplexChallenger; -type MyFriConfig = FriConfigImpl; +type MyFriConfig = FriConfig; -fn get_ldt_for_testing(rng: &mut R) -> (Perm, ValMmcs, FriLdt) { +fn get_ldt_for_testing(rng: &mut R) -> (Perm, MyFriConfig) { let perm = Perm::new_from_rng(8, 22, DiffusionMatrixBabybear, rng); let hash = MyHash::new(perm.clone()); let compress = MyCompress::new(perm.clone()); - let val_mmcs = ValMmcs::new(hash, compress); - let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); - let fri_config = MyFriConfig::new(1, 10, 8, challenge_mmcs); - (perm, val_mmcs, FriLdt { config: fri_config }) + let mmcs = ChallengeMmcs::new(ValMmcs::new(hash, compress)); + let fri_config = FriConfig { + log_blowup: 1, + num_queries: 10, + proof_of_work_bits: 8, + mmcs, + }; + (perm, fri_config) } fn do_test_fri_ldt(rng: &mut R) { - let (perm, val_mmcs, ldt) = get_ldt_for_testing(rng); + let (perm, fc) = get_ldt_for_testing(rng); let dft = Radix2Dit::default(); - let ldes: Vec> = (3..6) + let shift = Val::generator(); + + let ldes: Vec> = (3..10) .map(|deg_bits| { - let evals = RowMajorMatrix::::rand_nonzero(rng, 1 << deg_bits, 4); - let mut lde = dft.coset_lde_batch(evals, 1, Val::one()); + let evals = RowMajorMatrix::::rand_nonzero(rng, 1 << deg_bits, 16); + let mut lde = dft.coset_lde_batch(evals, 1, shift); reverse_matrix_index_bits(&mut lde); lde }) .collect(); - let dims = ldes.iter().map(|m| m.dimensions()).collect_vec(); - let 
(comm, data) = val_mmcs.commit(ldes); - let mut p_challenger = Challenger::new(perm.clone()); - let proof = ldt.prove(&[val_mmcs.clone()], &[&data], &mut p_challenger); + let (proof, reduced_openings, p_sample) = { + // Prover world + let mut chal = Challenger::new(perm.clone()); + let alpha: Challenge = chal.sample_ext_element(); + + let input: [_; 32] = core::array::from_fn(|log_height| { + let matrices_with_log_height: Vec<&RowMajorMatrix> = ldes + .iter() + .filter(|m| log2_strict_usize(m.height()) == log_height) + .collect(); + if matrices_with_log_height.is_empty() { + None + } else { + let reduced: Vec = (0..(1 << log_height)) + .map(|r| { + alpha + .powers() + .zip(matrices_with_log_height.iter().flat_map(|m| m.row(r))) + .map(|(alpha_pow, v)| alpha_pow * v) + .sum() + }) + .collect(); + Some(reduced) + } + }); + + let (proof, idxs) = prover::prove(&fc, &input, &mut chal); + + let log_max_height = input.iter().rposition(Option::is_some).unwrap(); + let reduced_openings: Vec<[Challenge; 32]> = idxs + .into_iter() + .map(|idx| { + input + .iter() + .enumerate() + .map(|(log_height, v)| { + if let Some(v) = v { + v[idx >> (log_max_height - log_height)] + } else { + Challenge::zero() + } + }) + .collect_vec() + .try_into() + .unwrap() + }) + .collect(); + + (proof, reduced_openings, chal.sample_bits(8)) + }; let mut v_challenger = Challenger::new(perm); - ldt.verify(&[val_mmcs], &[dims], &[comm], &proof, &mut v_challenger) - .expect("verification failed"); + let _alpha: Challenge = v_challenger.sample_ext_element(); + let fri_challenges = + verifier::verify_shape_and_sample_challenges(&fc, &proof, &mut v_challenger) + .expect("failed verify shape and sample"); + verifier::verify_challenges(&fc, &proof, &fri_challenges, &reduced_openings) + .expect("failed verify challenges"); assert_eq!( - p_challenger.sample(), - v_challenger.sample(), + p_sample, + v_challenger.sample_bits(8), "prover and verifier transcript have same state after FRI" ); } diff --git a/fri/tests/pcs.rs b/fri/tests/pcs.rs new file mode 100644 index 00000000..6b45961f --- /dev/null +++ b/fri/tests/pcs.rs @@ -0,0 +1,116 @@ +use p3_baby_bear::BabyBear; +use p3_challenger::{CanObserve, DuplexChallenger, FieldChallenger}; +use p3_commit::{ExtensionMmcs, Pcs, UnivariatePcs}; +use p3_dft::Radix2DitParallel; +use p3_field::extension::BinomialExtensionField; +use p3_field::Field; +use p3_fri::{FriConfig, TwoAdicFriPcs, TwoAdicFriPcsConfig}; +use p3_matrix::dense::RowMajorMatrix; +use p3_matrix::Matrix; +use p3_merkle_tree::FieldMerkleTreeMmcs; +use p3_poseidon2::{DiffusionMatrixBabybear, Poseidon2}; +use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; +use rand::thread_rng; + +fn make_test_fri_pcs(log_degrees: &[usize]) { + let mut rng = thread_rng(); + type Val = BabyBear; + type Challenge = BinomialExtensionField; + + type Perm = Poseidon2; + let perm = Perm::new_from_rng(8, 22, DiffusionMatrixBabybear, &mut rng); + + type MyHash = PaddingFreeSponge; + let hash = MyHash::new(perm.clone()); + + type MyCompress = TruncatedPermutation; + let compress = MyCompress::new(perm.clone()); + + type ValMmcs = FieldMerkleTreeMmcs<::Packing, MyHash, MyCompress, 8>; + let val_mmcs = ValMmcs::new(hash, compress); + + type ChallengeMmcs = ExtensionMmcs; + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + + type Dft = Radix2DitParallel; + let dft = Dft {}; + + type Challenger = DuplexChallenger; + + let fri_config = FriConfig { + log_blowup: 1, + num_queries: 10, + proof_of_work_bits: 8, + mmcs: challenge_mmcs, + }; + 
type Pcs =
+        TwoAdicFriPcs<TwoAdicFriPcsConfig<Val, Challenge, Challenger, Dft, ValMmcs, ChallengeMmcs>>;
+    let pcs = Pcs::new(fri_config, dft, val_mmcs);
+
+    let mut challenger = Challenger::new(perm.clone());
+
+    let polynomials = log_degrees
+        .iter()
+        .map(|d| RowMajorMatrix::rand(&mut rng, 1 << *d, 10))
+        .collect::<Vec<_>>();
+
+    let (commit, data) = pcs.commit_batches(polynomials.clone());
+
+    challenger.observe(commit);
+
+    let zeta = challenger.sample_ext_element::<Challenge>();
+
+    let points = polynomials.iter().map(|_| vec![zeta]).collect::<Vec<_>>();
+
+    let (opening, proof) = <Pcs as UnivariatePcs<_, _, RowMajorMatrix<Val>, _>>::open_multi_batches(
+        &pcs,
+        &[(&data, &points)],
+        &mut challenger,
+    );
+
+    // verify the proof.
+    let mut challenger = Challenger::new(perm);
+    challenger.observe(commit);
+    let _ = challenger.sample_ext_element::<Challenge>();
+    let dims = polynomials
+        .iter()
+        .map(|p| p.dimensions())
+        .collect::<Vec<_>>();
+    <Pcs as UnivariatePcs<_, _, RowMajorMatrix<Val>, _>>::verify_multi_batches(
+        &pcs,
+        &[(commit, &points)],
+        &[dims],
+        opening,
+        &proof,
+        &mut challenger,
+    )
+    .expect("verification error");
+}
+
+#[test]
+fn test_fri_pcs_single() {
+    make_test_fri_pcs(&[3]);
+}
+
+#[test]
+fn test_fri_pcs_many_equal() {
+    for i in 1..4 {
+        make_test_fri_pcs(&[i; 5]);
+    }
+}
+
+#[test]
+fn test_fri_pcs_many_different() {
+    for i in 2..4 {
+        let degrees = (3..3 + i).collect::<Vec<_>>();
+        make_test_fri_pcs(&degrees);
+    }
+}
+
+#[test]
+fn test_fri_pcs_many_different_rev() {
+    for i in 2..4 {
+        let degrees = (3..3 + i).rev().collect::<Vec<_>>();
+        make_test_fri_pcs(&degrees);
+    }
+}
diff --git a/keccak-air/Cargo.toml b/keccak-air/Cargo.toml
index c28f5b12..f3474019 100644
--- a/keccak-air/Cargo.toml
+++ b/keccak-air/Cargo.toml
@@ -19,7 +19,7 @@ p3-dft = { path = "../dft" }
 p3-fri = { path = "../fri" }
 p3-goldilocks = { path = "../goldilocks" }
 p3-keccak = { path = "../keccak" }
-p3-ldt = { path = "../ldt" }
+p3-maybe-rayon = { path = "../maybe-rayon" }
 p3-mds = { path = "../mds" }
 p3-merkle-tree = { path = "../merkle-tree" }
 p3-poseidon2 = { path = "../poseidon2" }
@@ -37,3 +37,6 @@ name = "prove_baby_bear_poseidon2"
 
 [[example]]
 name = "prove_goldilocks_keccak"
+
+[features]
+parallel = ["p3-maybe-rayon/parallel"]
diff --git a/keccak-air/examples/prove_baby_bear_keccak.rs b/keccak-air/examples/prove_baby_bear_keccak.rs
index dd5e8474..ea43473b 100644
--- a/keccak-air/examples/prove_baby_bear_keccak.rs
+++ b/keccak-air/examples/prove_baby_bear_keccak.rs
@@ -4,14 +4,13 @@ use p3_commit::ExtensionMmcs;
 use p3_dft::Radix2DitParallel;
 use p3_field::extension::BinomialExtensionField;
 use p3_field::Field;
-use p3_fri::{FriBasedPcs, FriConfigImpl, FriLdt};
+use p3_fri::{FriConfig, TwoAdicFriPcs, TwoAdicFriPcsConfig};
 use p3_keccak::Keccak256Hash;
 use p3_keccak_air::{generate_trace_rows, KeccakAir};
-use p3_ldt::QuotientMmcs;
 use p3_merkle_tree::FieldMerkleTreeMmcs;
 use p3_poseidon2::{DiffusionMatrixBabybear, Poseidon2};
 use p3_symmetric::{CompressionFunctionFromHasher, SerializingHasher32};
-use p3_uni_stark::{prove, verify, StarkConfigImpl, VerificationError};
+use p3_uni_stark::{prove, verify, StarkConfig, VerificationError};
 use rand::{random, thread_rng};
 use tracing_forest::util::LevelFilter;
 use tracing_forest::ForestLayer;
@@ -56,16 +55,19 @@ type Challenger = DuplexChallenger;
 
-    type Quotient = QuotientMmcs;
-    type MyFriConfig = FriConfigImpl;
-    let fri_config = MyFriConfig::new(1, 100, 16, challenge_mmcs);
-    let ldt = FriLdt { config: fri_config };
+    let fri_config = FriConfig {
+        log_blowup: 1,
+        num_queries: 100,
+        proof_of_work_bits: 16,
+        mmcs: challenge_mmcs,
+    };
+    type Pcs =
+        TwoAdicFriPcs<TwoAdicFriPcsConfig<Val, Challenge, Challenger, Dft, ValMmcs, ChallengeMmcs>>;
+    let pcs = Pcs::new(fri_config,
dft, val_mmcs); - type Pcs = FriBasedPcs; - type MyConfig = StarkConfigImpl; + type MyConfig = StarkConfig; + let config = StarkConfig::new(pcs); - let pcs = Pcs::new(dft, val_mmcs, ldt); - let config = StarkConfigImpl::new(pcs); let mut challenger = Challenger::new(perm.clone()); let inputs = (0..NUM_HASHES).map(|_| random()).collect::>(); diff --git a/keccak-air/examples/prove_baby_bear_poseidon2.rs b/keccak-air/examples/prove_baby_bear_poseidon2.rs index db6dc519..7308eb34 100644 --- a/keccak-air/examples/prove_baby_bear_poseidon2.rs +++ b/keccak-air/examples/prove_baby_bear_poseidon2.rs @@ -4,13 +4,12 @@ use p3_commit::ExtensionMmcs; use p3_dft::Radix2DitParallel; use p3_field::extension::BinomialExtensionField; use p3_field::Field; -use p3_fri::{FriBasedPcs, FriConfigImpl, FriLdt}; +use p3_fri::{FriConfig, TwoAdicFriPcs, TwoAdicFriPcsConfig}; use p3_keccak_air::{generate_trace_rows, KeccakAir}; -use p3_ldt::QuotientMmcs; use p3_merkle_tree::FieldMerkleTreeMmcs; use p3_poseidon2::{DiffusionMatrixBabybear, Poseidon2}; use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; -use p3_uni_stark::{prove, verify, StarkConfigImpl, VerificationError}; +use p3_uni_stark::{prove, verify, StarkConfig, VerificationError}; use rand::{random, thread_rng}; use tracing_forest::util::LevelFilter; use tracing_forest::ForestLayer; @@ -55,16 +54,19 @@ fn main() -> Result<(), VerificationError> { type Challenger = DuplexChallenger; - type Quotient = QuotientMmcs; - type MyFriConfig = FriConfigImpl; - let fri_config = MyFriConfig::new(1, 100, 16, challenge_mmcs); - let ldt = FriLdt { config: fri_config }; + let fri_config = FriConfig { + log_blowup: 1, + num_queries: 100, + proof_of_work_bits: 16, + mmcs: challenge_mmcs, + }; + type Pcs = + TwoAdicFriPcs>; + let pcs = Pcs::new(fri_config, dft, val_mmcs); - type Pcs = FriBasedPcs; - type MyConfig = StarkConfigImpl; + type MyConfig = StarkConfig; + let config = StarkConfig::new(pcs); - let pcs = Pcs::new(dft, val_mmcs, ldt); - let config = StarkConfigImpl::new(pcs); let mut challenger = Challenger::new(perm.clone()); let inputs = (0..NUM_HASHES).map(|_| random()).collect::>(); diff --git a/keccak-air/examples/prove_goldilocks_keccak.rs b/keccak-air/examples/prove_goldilocks_keccak.rs index 4467adbc..fa8fe9ac 100644 --- a/keccak-air/examples/prove_goldilocks_keccak.rs +++ b/keccak-air/examples/prove_goldilocks_keccak.rs @@ -3,15 +3,14 @@ use p3_commit::ExtensionMmcs; use p3_dft::Radix2DitParallel; use p3_field::extension::BinomialExtensionField; use p3_field::Field; -use p3_fri::{FriBasedPcs, FriConfigImpl, FriLdt}; +use p3_fri::{FriConfig, TwoAdicFriPcs, TwoAdicFriPcsConfig}; use p3_goldilocks::Goldilocks; use p3_keccak::Keccak256Hash; use p3_keccak_air::{generate_trace_rows, KeccakAir}; -use p3_ldt::QuotientMmcs; use p3_merkle_tree::FieldMerkleTreeMmcs; use p3_poseidon2::{DiffusionMatrixGoldilocks, Poseidon2}; use p3_symmetric::{CompressionFunctionFromHasher, SerializingHasher64}; -use p3_uni_stark::{prove, verify, StarkConfigImpl, VerificationError}; +use p3_uni_stark::{prove, verify, StarkConfig, VerificationError}; use rand::{random, thread_rng}; use tracing_forest::util::LevelFilter; use tracing_forest::ForestLayer; @@ -55,16 +54,19 @@ fn main() -> Result<(), VerificationError> { type Challenger = DuplexChallenger; - type Quotient = QuotientMmcs; - type MyFriConfig = FriConfigImpl; - let fri_config = MyFriConfig::new(1, 100, 16, challenge_mmcs); - let ldt = FriLdt { config: fri_config }; + let fri_config = FriConfig { + log_blowup: 1, + num_queries: 
100, + proof_of_work_bits: 16, + mmcs: challenge_mmcs, + }; + type Pcs = + TwoAdicFriPcs>; + let pcs = Pcs::new(fri_config, dft, val_mmcs); - type Pcs = FriBasedPcs; - type MyConfig = StarkConfigImpl; + type MyConfig = StarkConfig; + let config = StarkConfig::new(pcs); - let pcs = Pcs::new(dft, val_mmcs, ldt); - let config = StarkConfigImpl::new(pcs); let mut challenger = Challenger::new(perm.clone()); let inputs = (0..NUM_HASHES).map(|_| random()).collect::>(); diff --git a/ldt/Cargo.toml b/ldt/Cargo.toml deleted file mode 100644 index 3183969a..00000000 --- a/ldt/Cargo.toml +++ /dev/null @@ -1,25 +0,0 @@ -[package] -name = "p3-ldt" -version = "0.1.0" -edition = "2021" -license = "MIT OR Apache-2.0" - -[dependencies] -p3-challenger = { path = "../challenger" } -p3-commit = { path = "../commit" } -p3-dft = { path = "../dft" } -p3-field = { path = "../field" } -p3-interpolation = { path = "../interpolation" } -p3-matrix = { path = "../matrix" } -p3-util = { path = "../util" } - -itertools = "0.12.0" -tracing = "0.1.37" -serde = { version = "1.0", default-features = false } - -[dev-dependencies] -p3-baby-bear = { path = "../baby-bear" } -p3-symmetric = { path = "../symmetric" } -p3-merkle-tree = { path = "../merkle-tree" } -p3-blake3 = { path = "../blake3" } -rand = "0.8.5" diff --git a/ldt/src/ldt_based_pcs.rs b/ldt/src/ldt_based_pcs.rs deleted file mode 100644 index 7ccd157c..00000000 --- a/ldt/src/ldt_based_pcs.rs +++ /dev/null @@ -1,278 +0,0 @@ -use alloc::vec::Vec; -use core::marker::PhantomData; - -use itertools::Itertools; -use p3_challenger::FieldChallenger; -use p3_commit::{ - DirectMmcs, OpenedValues, OpenedValuesForMatrix, OpenedValuesForPoint, OpenedValuesForRound, - Pcs, UnivariatePcs, UnivariatePcsWithLde, -}; -use p3_dft::TwoAdicSubgroupDft; -use p3_field::extension::HasFrobenius; -use p3_field::{ExtensionField, TwoAdicField}; -use p3_interpolation::interpolate_coset; -use p3_matrix::bitrev::{BitReversableMatrix, BitReversedMatrixView}; -use p3_matrix::dense::RowMajorMatrixView; -use p3_matrix::{Dimensions, Matrix, MatrixRows}; -use tracing::{info_span, instrument}; - -use crate::quotient::QuotientMmcs; -use crate::{Ldt, Opening}; - -pub struct LdtBasedPcs { - dft: Dft, - mmcs: M, - ldt: L, - _phantom: PhantomData<(Val, EF, Challenger)>, -} - -impl LdtBasedPcs { - pub fn new(dft: Dft, mmcs: M, ldt: L) -> Self { - Self { - dft, - mmcs, - ldt, - _phantom: PhantomData, - } - } -} - -impl UnivariatePcsWithLde - for LdtBasedPcs -where - Val: TwoAdicField, - EF: ExtensionField + TwoAdicField + HasFrobenius, - In: MatrixRows, - Dft: TwoAdicSubgroupDft, - M: 'static + for<'a> DirectMmcs = RowMajorMatrixView<'a, Val>>, - L: Ldt, Challenger>, - Challenger: FieldChallenger, -{ - type Lde<'a> = BitReversedMatrixView> where Self: 'a; - - fn coset_shift(&self) -> Val { - Val::generator() - } - - fn log_blowup(&self) -> usize { - self.ldt.log_blowup() - } - - fn get_ldes<'a, 'b>(&'a self, prover_data: &'b Self::ProverData) -> Vec> - where - 'a: 'b, - { - // We committed to the bit-reversed LDE, so now we wrap it to return in natural order. 
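The bit-reversal bookkeeping in this deleted path is easy to get wrong, and it survives in the new code. A minimal sketch of the index mapping, assuming only `p3_util::reverse_bits_len` (which the new verifier also uses); `log_height` is an arbitrary illustrative value:

```rust
use p3_util::reverse_bits_len;

fn main() {
    // A committed LDE stores rows in bit-reversed order, so the evaluation at
    // the i-th point of the natural-order coset lives at row
    // reverse_bits_len(i, log_height) of the committed matrix. The mapping is
    // an involution, so the same call converts in both directions.
    let log_height = 3;
    for i in 0..(1usize << log_height) {
        let committed = reverse_bits_len(i, log_height);
        assert_eq!(reverse_bits_len(committed, log_height), i);
        println!("natural row {i} <-> committed row {committed}");
    }
}
```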
- self.mmcs - .get_matrices(prover_data) - .into_iter() - .map(|m| BitReversedMatrixView::new(m)) - .collect() - } - - fn commit_shifted_batches( - &self, - polynomials: Vec, - coset_shifts: &[Val], - ) -> (Self::Commitment, Self::ProverData) { - let ldes = info_span!("compute all coset LDEs").in_scope(|| { - polynomials - .into_iter() - .zip_eq(coset_shifts.iter()) - .map(|(poly, coset_shift)| { - let shift = Val::generator() / *coset_shift; - let input = poly.to_row_major_matrix(); - // Commit to the bit-reversed LDE. - self.dft - .coset_lde_batch(input, self.ldt.log_blowup(), shift) - .bit_reverse_rows() - .to_row_major_matrix() - }) - .collect() - }); - self.mmcs.commit(ldes) - } -} - -impl Pcs - for LdtBasedPcs -where - Val: TwoAdicField, - EF: ExtensionField + TwoAdicField + HasFrobenius, - In: MatrixRows, - Dft: TwoAdicSubgroupDft, - M: 'static + for<'a> DirectMmcs = RowMajorMatrixView<'a, Val>>, - L: Ldt, Challenger>, - Challenger: FieldChallenger, -{ - type Commitment = M::Commitment; - type ProverData = M::ProverData; - type Proof = L::Proof; - type Error = L::Error; - - fn commit_batches(&self, polynomials: Vec) -> (Self::Commitment, Self::ProverData) { - let shifts = (0..polynomials.len()) - .map(|_| Val::one()) - .collect::>(); - self.commit_shifted_batches(polynomials, &shifts) - } -} - -impl UnivariatePcs - for LdtBasedPcs -where - Val: TwoAdicField, - EF: ExtensionField + TwoAdicField + HasFrobenius, - In: MatrixRows, - Dft: TwoAdicSubgroupDft, - M: 'static + for<'a> DirectMmcs = RowMajorMatrixView<'a, Val>>, - L: Ldt, Challenger>, - Challenger: FieldChallenger, -{ - #[instrument(name = "prove batch opening", skip_all)] - fn open_multi_batches( - &self, - prover_data_and_points: &[(&Self::ProverData, &[Vec])], - challenger: &mut Challenger, - ) -> (OpenedValues, Self::Proof) { - let coset_shift: Val = - >::coset_shift(self); - - // Use Barycentric interpolation to evaluate each matrix at a given point. 
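A self-contained sketch of that barycentric path, assuming `interpolate_coset` takes natural-order evaluations over the coset `shift * H` and returns one value per column; the degree, width, and evaluation point below are arbitrary:

```rust
use p3_baby_bear::BabyBear;
use p3_dft::{Radix2Dit, TwoAdicSubgroupDft};
use p3_field::extension::BinomialExtensionField;
use p3_field::{AbstractField, Field};
use p3_interpolation::interpolate_coset;
use p3_matrix::dense::RowMajorMatrix;
use rand::thread_rng;

fn main() {
    type F = BabyBear;
    type EF = BinomialExtensionField<F, 4>;

    // Two random columns of degree < 8, given by their evaluations on the
    // shifted subgroup shift * H, then low-degree-extended by one bit.
    let shift = F::generator();
    let evals = RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 2);
    let lde = Radix2Dit::default().coset_lde_batch(evals, 1, shift);

    // Barycentric evaluation of every column at an out-of-domain extension
    // point, without recovering coefficients via an inverse DFT.
    let point = EF::from_canonical_u32(12345);
    let at_point = interpolate_coset(&lde, shift, point);
    assert_eq!(at_point.len(), 2); // one evaluation per column
}
```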
- let eval_at_points = |matrix: M::Mat<'_>, points: &[EF]| { - points - .iter() - .map(|&point| { - let (low_coset, _) = - matrix.split_rows(matrix.height() >> self.ldt.log_blowup()); - interpolate_coset(&BitReversedMatrixView::new(low_coset), coset_shift, point) - }) - .collect::>() - }; - - let all_opened_values = info_span!("compute opened values with Lagrange interpolation") - .in_scope(|| { - prover_data_and_points - .iter() - .map(|(data, points_per_matrix)| { - let matrices = self.mmcs.get_matrices(data); - - matrices - .iter() - .zip(*points_per_matrix) - .map(|(&mat, points_for_mat)| eval_at_points(mat, points_for_mat)) - .collect::>() - }) - .collect::>() - }); - - let (prover_data, all_points): (Vec<_>, Vec<_>) = - prover_data_and_points.iter().copied().unzip(); - - let quotient_mmcs = all_points - .into_iter() - .zip(&all_opened_values) - .map( - |(points_for_round, opened_values_for_round_by_matrix): ( - &[Vec], - &OpenedValuesForRound, - )| { - // let opened_values_for_round_by_point = - // transpose(opened_values_for_round_by_matrix.to_vec()); - - let openings = opened_values_for_round_by_matrix - .iter() - .zip(points_for_round) - .map( - |(opened_values_for_matrix, points_for_matrix): ( - &OpenedValuesForMatrix, - &Vec, - )| { - opened_values_for_matrix - .iter() - .zip(points_for_matrix) - .map( - |(opened_values_for_point, &point): ( - &OpenedValuesForPoint, - &EF, - )| { - Opening::::new( - point, - opened_values_for_point.clone(), - ) - }, - ) - .collect::>>() - }, - ) - .collect(); - QuotientMmcs:: { - inner: self.mmcs.clone(), - openings, - coset_shift, - _phantom: PhantomData, - } - }, - ) - .collect_vec(); - - let proof = self.ldt.prove(&quotient_mmcs, &prover_data, challenger); - (all_opened_values, proof) - } - - fn verify_multi_batches( - &self, - commits_and_points: &[(Self::Commitment, &[Vec])], - dims: &[Vec], - values: OpenedValues, - proof: &Self::Proof, - challenger: &mut Challenger, - ) -> Result<(), Self::Error> { - let (commits, points): (Vec, Vec<&[Vec]>) = - commits_and_points.iter().cloned().unzip(); - let coset_shift: Val = - >::coset_shift(self); - let (dims, quotient_mmcs): (Vec<_>, Vec<_>) = points - .into_iter() - .zip_eq(values) - .zip_eq(dims) - .map( - #[allow(clippy::type_complexity)] - |((points, opened_values_for_round_by_matrix), dims): ( - (&[Vec], OpenedValuesForRound), - &Vec, - )| { - let openings = opened_values_for_round_by_matrix - .into_iter() - .zip(points) - .map(|(opened_values_for_matrix_by_point, points_for_matrix)| { - points_for_matrix - .iter() - .zip(opened_values_for_matrix_by_point) - .map(|(&point, opened_values_for_point)| { - Opening::::new(point, opened_values_for_point) - }) - .collect() - }) - .collect_vec(); - ( - dims.iter() - .map(|d| Dimensions { - width: d.width * openings.len(), - height: d.height << self.ldt.log_blowup(), - }) - .collect_vec(), - QuotientMmcs:: { - inner: self.mmcs.clone(), - openings, - coset_shift, - _phantom: PhantomData, - }, - ) - }, - ) - .unzip(); - self.ldt - .verify(&quotient_mmcs, &dims, &commits, proof, challenger) - } -} diff --git a/ldt/src/lib.rs b/ldt/src/lib.rs deleted file mode 100644 index 68886439..00000000 --- a/ldt/src/lib.rs +++ /dev/null @@ -1,53 +0,0 @@ -//! This crate contains a framework for low-degree tests (LDTs).
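With this `Ldt` abstraction gone, the replacement flow is the pair of verifier entry points introduced in fri/src/verifier.rs and exercised in fri/tests/fri.rs above. Condensed here for reference, not self-contained: `fc`, `proof`, `reduced_openings`, `perm`, and the `Challenger`/`Challenge` aliases are as constructed in that test:

```rust
// The transcript must match the prover: sample alpha first, then the FRI
// challenges (indices + betas), then check the claimed reduced openings.
let mut challenger = Challenger::new(perm);
let _alpha: Challenge = challenger.sample_ext_element();
let fri_challenges =
    verifier::verify_shape_and_sample_challenges(&fc, &proof, &mut challenger)
        .expect("proof shape or PoW witness rejected");
verifier::verify_challenges(&fc, &proof, &fri_challenges, &reduced_openings)
    .expect("folded evaluations do not match the final poly");
```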
- -#![no_std] - -mod ldt_based_pcs; -mod quotient; - -use alloc::vec::Vec; - -pub use ldt_based_pcs::*; -use p3_challenger::FieldChallenger; -use p3_commit::Mmcs; -use p3_field::Field; -use p3_matrix::Dimensions; -pub use quotient::*; -use serde::de::DeserializeOwned; -use serde::Serialize; - -extern crate alloc; - -/// A batch low-degree test (LDT). -pub trait Ldt -where - Val: Field, - M: Mmcs, - Challenger: FieldChallenger, -{ - type Proof: Serialize + DeserializeOwned; - type Error; - - fn log_blowup(&self) -> usize; - - fn blowup(&self) -> usize { - 1 << self.log_blowup() - } - - /// Prove that each column of each matrix in `codewords` is a codeword. - fn prove( - &self, - input_mmcs: &[M], - input_data: &[&M::ProverData], - challenger: &mut Challenger, - ) -> Self::Proof; - - fn verify( - &self, - input_mmcs: &[M], - input_dims: &[Vec], - input_commits: &[M::Commitment], - proof: &Self::Proof, - challenger: &mut Challenger, - ) -> Result<(), Self::Error>; -} diff --git a/ldt/src/quotient.rs b/ldt/src/quotient.rs deleted file mode 100644 index 12900f49..00000000 --- a/ldt/src/quotient.rs +++ /dev/null @@ -1,527 +0,0 @@ -use alloc::vec; -use alloc::vec::Vec; -use core::fmt::Debug; -use core::marker::PhantomData; -use core::mem::MaybeUninit; -use core::{debug_assert_eq, slice}; - -use itertools::{izip, Itertools}; -use p3_commit::Mmcs; -use p3_field::extension::HasFrobenius; -use p3_field::{ - add_vecs, batch_multiplicative_inverse, binomial_expand, cyclic_subgroup_coset_known_order, - eval_poly, scale_vec, Field, PackedField, TwoAdicField, -}; -use p3_matrix::dense::RowMajorMatrix; -use p3_matrix::{Dimensions, Matrix, MatrixRowSlices, MatrixRows}; -use p3_util::{log2_strict_usize, reverse_bits_len, reverse_slice_index_bits}; - -/// A wrapper around an Inner MMCS, which transforms each inner value to (inner - r(X)) / m(X), -/// where m(X) is the minimal polynomial of the opening point, and r(X) = inner mod m(X). -/// -/// This technique was proposed by Liam Eagen. -/// Instead of providing the quotient (p(X) - p(alpha))/(X - alpha) in the extension field, -/// we express the value of p(X) at X=alpha as the remainder r(X) = p(X) mod m(X), -/// where m(X) is the minimal polynomial such that m(alpha) = 0, -/// and prove r(X) is correct by showing (p(X) - r(X)) is divisible by m(X). -/// -/// This has the benefit that all coefficients and evaluations are performed in the base field. -/// -/// Since we have the values p(alpha) = y, we can recover r(X) by interpolating -/// [(alpha,y), (Frob alpha, Frob y), (Frob^2 alpha, Frob^2 y), ..] -/// since the Galois action commutes with polynomials with coefficients over the base field. -/// -/// Since there can be multiple opening points, for each matrix, this transforms an inner opened row -/// into a concatenation of rows, transformed as above, for each point. -#[derive(Clone)] -pub struct QuotientMmcs> { - pub(crate) inner: Inner, - - /// For each matrix, a list of claimed openings, one for each point that we open that batch of - /// polynomials at. - pub(crate) openings: Vec>>, - - // The coset shift for the inner MMCS's evals, to correct `x` in the denominator. - pub(crate) coset_shift: F, - - // QuotientMmcs and Opening, once constructed, technically do not need to know - // anything about the extension field. However, we keep it as a generic so that - // we can unroll the inner loop of `compute_quotient_matrix_row` over EF::D. - pub(crate) _phantom: PhantomData, -} - -/// A claimed opening. 
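Spelled out, for an opening point alpha whose minimal polynomial has full degree D (the generic case for a random extension-field challenge) and with sigma the Frobenius, the transform described in the comment above is:

```
m(X) = \prod_{i=0}^{D-1} \bigl(X - \sigma^i(\alpha)\bigr), \qquad
r(X) = p(X) \bmod m(X), \qquad
q(X) = \frac{p(X) - r(X)}{m(X)}.
```

Both m and r have base-field coefficients, deg r < D, and r(alpha) = p(alpha) is the claimed opening, so a low-degree check on q convinces the verifier of the claimed value without any extension-field quotients.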
-#[derive(Clone, Debug)] -pub(crate) struct Opening { - // point.minimal_poly() - pub(crate) minpoly: Vec, - // for each column, the remainder poly r(X) = p(X) mod m(X) - pub(crate) remainder_polys: Vec>, - - // each remainder poly always has degree EF::D. - // so, the width of this matrix is EF::D, and the height is - // `openings.len() // F::Packing::WIDTH`. - // this matrix is missing any remaining coefficients that don't divide - // F::Packing::WIDTH evenly, so you have to get those from remainder_polys. - pub(crate) r_vertically_packed: Option>, - - pub(crate) _phantom: PhantomData, -} - -impl> Opening { - pub(crate) fn new(point: EF, values: Vec) -> Self { - let remainder_polys = Self::compute_remainder_polys(point, &values); - let r_vertically_packed = vertical_pack(&remainder_polys); - Self { - minpoly: point.minimal_poly(), - remainder_polys, - r_vertically_packed, - _phantom: PhantomData, - } - } - fn compute_remainder_polys(point: EF, values: &[EF]) -> Vec> { - // compute lagrange basis for [point, Frob point, Frob^2 point, ..] - let xs = point.galois_group(); - debug_assert_eq!(xs.len(), EF::D); - let w = xs[1..] - .iter() - .map(|&xi| xs[0] - xi) - .product::() - .inverse(); - let l_point = scale_vec(w, binomial_expand(&xs[1..])); - debug_assert_eq!(l_point.len(), EF::D); - // interpolate at [(pt,value),(Frob pt, Frob alpha),..] - let mut rs = vec![]; - for &v in values { - let mut l_point_frob = scale_vec(v, l_point.clone()); - let mut r = l_point_frob.clone(); - for _ in 1..EF::D { - l_point_frob.iter_mut().for_each(|c| *c = c.frobenius()); - r = add_vecs(r, l_point_frob.clone()); - } - rs.push( - r.into_iter() - .map(|c| c.as_base().expect("Extension is not algebraic?")) - .collect(), - ); - } - rs - } -} - -/// For input `[ -/// [ 1, 2, 3], \ pack -/// [ 4, 5, 6], / -/// [ 7, 8, 9], \ pack -/// [10,11,12], / -/// [13,14,15], -/// ]`, -/// and F::Packing::WIDTH = 2, returns `[ -/// [P(1, 4), P(2, 5), P(3, 6)], -/// [P(7,10), P(8,11), P(9,12)], -/// ]` -/// where P(..) is a packed field. Trailing values (`[13,14,15]` above) are ignored. 
-fn vertical_pack(polys: &[Vec]) -> Option> { - let width = polys[0].len(); - let height = polys.len() / F::Packing::WIDTH; - if height == 0 { - return None; - } - Some(RowMajorMatrix::new( - (0..height) - .flat_map(|r| { - (0..width) - .map(move |c| F::Packing::from_fn(move |i| polys[r * F::Packing::WIDTH + i][c])) - }) - .collect(), - width, - )) -} - -#[derive(Debug, PartialEq, Eq)] -pub enum QuotientError { - InnerMmcs(InnerMmcsError), - OriginalValueMismatch, -} - -impl Mmcs for QuotientMmcs -where - F: TwoAdicField, - EF: HasFrobenius, - Inner: Mmcs, - for<'a> Inner::Mat<'a>: MatrixRowSlices, -{ - type ProverData = Inner::ProverData; - type Commitment = Inner::Commitment; - type Proof = Inner::Proof; - type Error = QuotientError; - type Mat<'a> = QuotientMatrix> where Self: 'a; - - fn open_batch( - &self, - index: usize, - prover_data: &Self::ProverData, - ) -> (Vec>, Self::Proof) { - let (inner_values, proof) = self.inner.open_batch(index, prover_data); - let matrix_heights = self.inner.get_matrix_heights(prover_data); - let max_height = *matrix_heights.iter().max().unwrap(); - let log_max_height = log2_strict_usize(max_height); - - let quotients = izip!(inner_values, &self.openings, matrix_heights) - .map(|(inner_row, openings_for_mat, height)| { - let log_height = log2_strict_usize(height); - let bits_reduced = log_max_height - log_height; - let reduced_index = reverse_bits_len(index >> bits_reduced, log_height); - let x = self.coset_shift - * F::two_adic_generator(log_height).exp_u64(reduced_index as u64); - - let m_invs = batch_multiplicative_inverse( - &openings_for_mat - .iter() - .map(|opening| eval_poly(&opening.minpoly, x)) - .collect_vec(), - ); - compute_quotient_matrix_row(x, openings_for_mat, &m_invs, &inner_row) - }) - .collect(); - - (quotients, proof) - } - - fn get_matrices<'a>(&'a self, prover_data: &'a Self::ProverData) -> Vec> { - self.inner - .get_matrices(prover_data) - .into_iter() - .zip(self.openings.clone()) - .map(|(inner, openings)| { - let height = inner.height(); - let log_height = log2_strict_usize(height); - let g = F::two_adic_generator(log_height); - let mut subgroup = - cyclic_subgroup_coset_known_order(g, self.coset_shift, height).collect_vec(); - reverse_slice_index_bits(&mut subgroup); - - let denominators: Vec = subgroup - .iter() - .flat_map(|&x| { - openings - .iter() - .map(move |opening| eval_poly(&opening.minpoly, x)) - }) - .collect(); - let inv_denominators = RowMajorMatrix::new( - batch_multiplicative_inverse(&denominators), - openings.len(), - ); - - QuotientMatrix { - inner, - subgroup, - openings, - inv_denominators, - _phantom: PhantomData, - } - }) - .collect() - } - - fn verify_batch( - &self, - commit: &Self::Commitment, - dimensions: &[Dimensions], - index: usize, - opened_quotient_values: &[Vec], - proof: &Self::Proof, - ) -> Result<(), Self::Error> { - // quotient = (original - r(X))/m(X) - // original = quotient * m(X) + r(X) - - let log_max_height = dimensions - .iter() - .map(|dims| log2_strict_usize(dims.height)) - .max() - .unwrap(); - - let opened_original_values = izip!(opened_quotient_values, &self.openings, dimensions) - .map(|(quotient_row, openings, dims)| { - let log_height = log2_strict_usize(dims.height); - let bits_reduced = log_max_height - log_height; - let reduced_index = reverse_bits_len(index >> bits_reduced, log_height); - let x = self.coset_shift - * F::two_adic_generator(log_height).exp_u64(reduced_index as u64); - - let original_width = quotient_row.len() / openings.len(); - let original_row_repeated: 
Vec> = quotient_row - .chunks(original_width) - .zip(openings) - .map(|(quotient_row_chunk, opening)| { - quotient_row_chunk - .iter() - .zip(&opening.remainder_polys) - .map(|(&quotient_value, r)| { - quotient_value * eval_poly(&opening.minpoly, x) + eval_poly(r, x) - }) - .collect_vec() - }) - .collect_vec(); - get_repeated(original_row_repeated.into_iter()) - .ok_or(QuotientError::OriginalValueMismatch) - }) - .collect::, Self::Error>>()?; - - self.inner - .verify_batch(commit, dimensions, index, &opened_original_values, proof) - .map_err(QuotientError::InnerMmcs) - } -} - -#[derive(Clone)] -pub struct QuotientMatrix> { - inner: Inner, - subgroup: Vec, - openings: Vec>, - /// For each row (associated with a subgroup element `x`), for each opening point, - /// this holds `1 / m(X)`. - inv_denominators: RowMajorMatrix, - _phantom: PhantomData, -} - -impl> Matrix for QuotientMatrix { - fn width(&self) -> usize { - self.inner.width() * self.openings.len() - } - - fn height(&self) -> usize { - self.inner.height() - } -} - -impl, Inner: MatrixRowSlices> MatrixRows - for QuotientMatrix -{ - type Row<'a> = Vec where Inner: 'a; - - #[inline] - fn row(&self, r: usize) -> Self::Row<'_> { - compute_quotient_matrix_row( - self.subgroup[r], - &self.openings, - self.inv_denominators.row_slice(r), - self.inner.row_slice(r), - ) - } - - #[inline] - fn row_vec(&self, r: usize) -> Vec { - self.row(r) - } -} - -fn compute_quotient_matrix_row>( - x: F, - openings: &[Opening], - m_invs: &[F], - inner_row: &[F], -) -> Vec { - let mut qp_ys: Vec = Vec::with_capacity(inner_row.len() * openings.len()); - - // [P(1,1,1,1), P(x,x,x,x),P(x^2,x^2,x^2,x^2),..] - let packed_x_pows = x - .powers() - .take(EF::D) - .map(|x_pow| F::Packing::from(x_pow)) - .collect_vec(); - - for (opening, &m_inv) in openings.iter().zip(m_invs) { - let packed_m_inv = F::Packing::from(m_inv); - let (packed_ys, sfx_ys) = F::Packing::pack_slice_with_suffix(inner_row); - - if let Some(r_vertically_packed) = &opening.r_vertically_packed { - let uninit = qp_ys.spare_capacity_mut(); - assert!(uninit.len() >= packed_ys.len() * F::Packing::WIDTH); - let packed_uninit = unsafe { - slice::from_raw_parts_mut( - uninit.as_mut_ptr().cast::>(), - packed_ys.len(), - ) - }; - for (packed_qp_y, &packed_y, coeffs) in - izip!(packed_uninit, packed_ys, r_vertically_packed.rows()) - { - let mut r_at_x = coeffs[0]; - for i in 1..EF::D { - r_at_x += coeffs[i] * packed_x_pows[i]; - } - let qp_y = (packed_y - r_at_x) * packed_m_inv; - packed_qp_y.write(qp_y); - } - unsafe { - qp_ys.set_len(qp_ys.len() + packed_ys.len() * F::Packing::WIDTH); - } - } - - sfx_ys - .iter() - .zip(&opening.remainder_polys[opening.remainder_polys.len() - sfx_ys.len()..]) - .for_each(|(&y, r)| qp_ys.push((y - eval_poly(r, x)) * m_inv)); - } - - qp_ys -} - -/// Checks that the given iterator contains repetitions of a single item, and returns that item.
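Illustratively, as a hedged sketch rather than the crate's own tests (the function is module-private, so this only makes sense inline):

```rust
// Within this module one would observe:
assert_eq!(get_repeated([7, 7, 7].into_iter()), Some(7));
assert_eq!(get_repeated([7, 8, 7].into_iter()), None);
// get_repeated(core::iter::empty::<u32>()) panics: "get_repeated on empty iterator".
```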
-fn get_repeated>(mut iter: I) -> Option { - let first = iter.next().expect("get_repeated on empty iterator"); - for x in iter { - if x != first { - return None; - } - } - Some(first) -} - -#[cfg(test)] -mod tests { - use p3_baby_bear::BabyBear; - use p3_blake3::Blake3; - use p3_commit::DirectMmcs; - use p3_dft::{Radix2Dit, TwoAdicSubgroupDft}; - use p3_field::extension::BinomialExtensionField; - use p3_field::{AbstractExtensionField, AbstractField}; - use p3_interpolation::{interpolate_coset, interpolate_subgroup}; - use p3_matrix::bitrev::{BitReversableMatrix, BitReversedMatrixView}; - use p3_merkle_tree::FieldMerkleTreeMmcs; - use p3_symmetric::{CompressionFunctionFromHasher, SerializingHasher32}; - use rand::distributions::Standard; - use rand::prelude::Distribution; - use rand::{thread_rng, Rng}; - - use super::*; - - type F = BabyBear; - type F4 = BinomialExtensionField; - type F5 = BinomialExtensionField; - type MyHash = SerializingHasher32; - type MyCompress = CompressionFunctionFromHasher; - type ValMmcs = FieldMerkleTreeMmcs; - - #[test] - fn test_remainder_polys() { - let trace: RowMajorMatrix = RowMajorMatrix::rand(&mut thread_rng(), 32, 5); - let point: F4 = thread_rng().gen(); - let values = interpolate_subgroup(&trace, point); - let rs = Opening::compute_remainder_polys(point, &values); - for (r, y) in rs.into_iter().zip(values) { - // r(alpha) = p(alpha) - assert_eq!( - eval_poly(&r.into_iter().map(F4::from_base).collect_vec(), point), - y - ); - } - } - - fn test_quotient_mmcs_with_sizes>( - num_openings: usize, - trace_sizes: &[(usize, usize)], - ) where - Standard: Distribution, - { - let hash = MyHash::new(Blake3 {}); - let compress = MyCompress::new(hash); - let inner = ValMmcs::new(hash, compress); - - let shift = F::generator(); - - let alphas: Vec = (0..num_openings).map(|_| thread_rng().gen()).collect_vec(); - - let max_height = trace_sizes.iter().map(|&(h, _)| h).max().unwrap(); - let max_height_bits = log2_strict_usize(max_height); - - let (traces, ldes, dims, openings): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) = trace_sizes - .iter() - .map(|&(height, width)| { - let trace = RowMajorMatrix::::rand_nonzero(&mut thread_rng(), height, width); - let lde = Radix2Dit::default() - .coset_lde_batch(trace.clone(), 1, shift) - .bit_reverse_rows(); - let dims = lde.dimensions(); - let lde_truncated = - RowMajorMatrix::new((0..height).flat_map(|r| lde.row(r)).collect_vec(), width); - let openings = alphas - .iter() - .map(|&alpha| { - Opening::new( - alpha, - interpolate_coset( - &BitReversedMatrixView::new(lde_truncated.clone()), - shift, - alpha, - ), - ) - }) - .collect_vec(); - (trace, lde.to_row_major_matrix(), dims, openings) - }) - .multiunzip(); - - let (comm, data) = inner.commit(ldes); - let mmcs = QuotientMmcs:: { - inner, - openings, - coset_shift: shift, - _phantom: PhantomData, - }; - - let index = thread_rng().gen_range(0..max_height); - let (opened_values, proof) = mmcs.open_batch(index, &data); - assert_eq!( - mmcs.verify_batch(&comm, &dims, index, &opened_values, &proof), - Ok(()) - ); - let mut bad_opened_values = opened_values.clone(); - bad_opened_values[0][0] += thread_rng().gen::(); - assert!(mmcs - .verify_batch(&comm, &dims, index, &bad_opened_values, &proof) - .is_err()); - - let mats = mmcs.get_matrices(&data); - for (trace, mat, opened_values_for_mat) in izip!(traces, mats, opened_values) { - let mat = mat.clone().to_row_major_matrix(); - - let height_bits = log2_strict_usize(trace.height()); - let reduced_index = index >> (max_height_bits - 
height_bits); - - // check that open_batch and get_matrices are consistent - assert_eq!(mat.row_slice(reduced_index), &opened_values_for_mat); - - // check low degree - let poly = - Radix2Dit::default().idft_batch(mat.bit_reverse_rows().to_row_major_matrix()); - let expected_degree = trace.height() - >::D; - assert!((expected_degree..poly.height()).all(|r| poly.row(r).all(|x| x.is_zero()))); - } - } - - #[test] - fn test_quotient_mmcs() { - let sizes: &[&[(usize, usize)]] = &[ - // single matrix - &[(16, 1)], - &[(16, 10)], - &[(16, 14)], - // multi matrix, same size - &[(16, 5), (16, 10)], - &[(8, 10), (8, 5)], - // multi matrix, different size - &[(16, 10), (32, 5)], - &[(32, 52), (8, 30)], - ]; - for num_openings in [1, 2, 3] { - for sizes in sizes { - test_quotient_mmcs_with_sizes::(num_openings, sizes); - // make sure it works when Packing::WIDTH != Extension::D - test_quotient_mmcs_with_sizes::(num_openings, sizes); - } - } - } -} diff --git a/matrix/src/dense.rs b/matrix/src/dense.rs index 9e47ae39..56797285 100644 --- a/matrix/src/dense.rs +++ b/matrix/src/dense.rs @@ -246,6 +246,13 @@ impl<'a, T> RowMajorMatrixView<'a, T> { self.values.chunks_exact(self.width) } + pub fn par_rows(&self) -> impl IndexedParallelIterator + where + T: Sync, + { + self.values.par_chunks_exact(self.width) + } + pub fn split_rows(&self, r: usize) -> (RowMajorMatrixView, RowMajorMatrixView) { let (upper_values, lower_values) = self.values.split_at(r * self.width); let upper = RowMajorMatrixView { diff --git a/maybe-rayon/src/lib.rs b/maybe-rayon/src/lib.rs index dedd7473..ec5c2ddf 100644 --- a/maybe-rayon/src/lib.rs +++ b/maybe-rayon/src/lib.rs @@ -1,5 +1,6 @@ #[cfg(feature = "parallel")] pub mod prelude { + pub use rayon::join; pub use rayon::prelude::*; } diff --git a/maybe-rayon/src/serial.rs b/maybe-rayon/src/serial.rs index aa75ba72..f2dbd6ff 100644 --- a/maybe-rayon/src/serial.rs +++ b/maybe-rayon/src/serial.rs @@ -28,6 +28,7 @@ pub trait IntoParallelRefIterator<'data> { fn par_iter(&'data self) -> Self::Iter; } + impl<'data, I: 'data + ?Sized> IntoParallelRefIterator<'data> for I where &'data I: IntoParallelIterator, @@ -46,6 +47,7 @@ pub trait IntoParallelRefMutIterator<'data> { fn par_iter_mut(&'data mut self) -> Self::Iter; } + impl<'data, I: 'data + ?Sized> IntoParallelRefMutIterator<'data> for I where &'data mut I: IntoParallelIterator, @@ -163,3 +165,13 @@ impl ParIterExt for T { self.flat_map(map_op) } } + +pub fn join(oper_a: A, oper_b: B) -> (RA, RB) +where + A: FnOnce() -> RA, + B: FnOnce() -> RB, +{ + let result_a = oper_a(); + let result_b = oper_b(); + (result_a, result_b) +} diff --git a/merkle-tree/Cargo.toml b/merkle-tree/Cargo.toml index 9055f5c2..759201a3 100644 --- a/merkle-tree/Cargo.toml +++ b/merkle-tree/Cargo.toml @@ -13,7 +13,7 @@ p3-commit = { path = "../commit" } p3-util = { path = "../util" } itertools = "0.12.0" tracing = "0.1.37" -serde = { version = "1.0", features = ["derive"] } +serde = { version = "1.0", default-features = false, features = ["alloc"] } [dev-dependencies] p3-blake3 = { path = "../blake3" } diff --git a/multi-stark/src/config.rs b/multi-stark/src/config.rs index 3a5e00e4..f837e87b 100644 --- a/multi-stark/src/config.rs +++ b/multi-stark/src/config.rs @@ -5,7 +5,7 @@ use p3_commit::MultivariatePcs; use p3_field::{AbstractExtensionField, ExtensionField, Field, PackedField}; use p3_matrix::dense::RowMajorMatrixView; -pub trait StarkConfig { +pub trait StarkGenericConfig { /// A value of the trace. 
type Val: Field; @@ -28,13 +28,13 @@ pub trait StarkConfig { fn pcs(&self) -> &Self::Pcs; } -pub struct StarkConfigImpl { +pub struct StarkConfig { pcs: Pcs, _phantom: PhantomData<(Val, Challenge, PackedChallenge, Challenger)>, } impl - StarkConfigImpl + StarkConfig { pub fn new(pcs: Pcs) -> Self { Self { @@ -44,8 +44,8 @@ impl } } -impl StarkConfig - for StarkConfigImpl +impl StarkGenericConfig + for StarkConfig where Val: Field, Challenge: ExtensionField, diff --git a/multi-stark/src/prover.rs b/multi-stark/src/prover.rs index 78241533..d4c6513b 100644 --- a/multi-stark/src/prover.rs +++ b/multi-stark/src/prover.rs @@ -3,7 +3,7 @@ use p3_challenger::FieldChallenger; use p3_commit::Pcs; use p3_matrix::dense::RowMajorMatrix; -use crate::{ConstraintFolder, StarkConfig}; +use crate::{ConstraintFolder, StarkGenericConfig}; pub fn prove( config: &SC, @@ -11,7 +11,7 @@ pub fn prove( _challenger: &mut Challenger, trace: RowMajorMatrix, ) where - SC: StarkConfig, + SC: StarkGenericConfig, A: for<'a> Air>, Challenger: FieldChallenger, { diff --git a/multi-stark/tests/mul_air.rs b/multi-stark/tests/mul_air.rs index da56be46..7f3c4eac 100644 --- a/multi-stark/tests/mul_air.rs +++ b/multi-stark/tests/mul_air.rs @@ -1,5 +1,5 @@ // use p3_air::{Air, AirBuilder}; -// use p3_multi_stark::{prove, StarkConfigImpl}; +// use p3_multi_stark::{prove, StarkConfig}; // use p3_challenger::DuplexChallenger; // use p3_fri::FRIBasedPcs; // use p3_lde::NaiveCosetLde; @@ -55,12 +55,12 @@ // // type Mmcs = MerkleTreeMMCS; // type Pcs = TensorPcs; -// type MyConfig = StarkConfigImpl; +// type MyConfig = StarkConfig; // // let mut rng = thread_rng(); // let trace = RowMajorMatrix::rand(&mut rng, 256, 10); // let pcs = todo!(); -// let config = StarkConfigImpl::new(pcs); +// let config = StarkConfig::new(pcs); // let mut challenger = DuplexChallenger::new(perm); // prove::(&MulAir, config, &mut challenger, trace); // } diff --git a/uni-stark/Cargo.toml b/uni-stark/Cargo.toml index 665abd88..475cd3df 100644 --- a/uni-stark/Cargo.toml +++ b/uni-stark/Cargo.toml @@ -15,12 +15,11 @@ p3-maybe-rayon = { path = "../maybe-rayon" } p3-util = { path = "../util" } itertools = "0.12.0" tracing = "0.1.37" -serde = { version = "1.0", default-features = false, features = ["derive"] } +serde = { version = "1.0", default-features = false, features = ["derive", "alloc"] } [dev-dependencies] p3-baby-bear = { path = "../baby-bear" } p3-fri = { path = "../fri" } -p3-ldt = { path = "../ldt" } p3-mds = { path = "../mds" } p3-merkle-tree = { path = "../merkle-tree" } p3-goldilocks = { path = "../goldilocks" } diff --git a/uni-stark/src/check_constraints.rs b/uni-stark/src/check_constraints.rs index 7ba85f11..19718976 100644 --- a/uni-stark/src/check_constraints.rs +++ b/uni-stark/src/check_constraints.rs @@ -2,7 +2,9 @@ use p3_air::{Air, AirBuilder, TwoRowMatrixView}; use p3_field::Field; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::{Matrix, MatrixRowSlices}; +use tracing::instrument; +#[instrument(name = "check constraints", skip_all)] pub(crate) fn check_constraints(air: &A, main: &RowMajorMatrix) where F: Field, diff --git a/uni-stark/src/config.rs b/uni-stark/src/config.rs index 030d2661..a973ddbb 100644 --- a/uni-stark/src/config.rs +++ b/uni-stark/src/config.rs @@ -5,7 +5,7 @@ use p3_commit::{Pcs, UnivariatePcsWithLde}; use p3_field::{AbstractExtensionField, ExtensionField, PackedField, TwoAdicField}; use p3_matrix::dense::RowMajorMatrix; -pub trait StarkConfig { +pub trait StarkGenericConfig { /// The field over which trace data 
is encoded. type Val: TwoAdicField; type PackedVal: PackedField; @@ -29,13 +29,13 @@ pub trait StarkConfig { fn pcs(&self) -> &Self::Pcs; } -pub struct StarkConfigImpl { +pub struct StarkConfig { pcs: Pcs, _phantom: PhantomData<(Val, Challenge, PackedChallenge, Challenger)>, } impl - StarkConfigImpl + StarkConfig { pub fn new(pcs: Pcs) -> Self { Self { @@ -45,8 +45,8 @@ impl } } -impl StarkConfig - for StarkConfigImpl +impl StarkGenericConfig + for StarkConfig where Val: TwoAdicField, Challenge: ExtensionField + TwoAdicField, diff --git a/uni-stark/src/decompose.rs b/uni-stark/src/decompose.rs index eed99f28..1858d614 100644 --- a/uni-stark/src/decompose.rs +++ b/uni-stark/src/decompose.rs @@ -1,7 +1,6 @@ use alloc::vec; use alloc::vec::Vec; -use itertools::izip; use p3_field::{AbstractExtensionField, TwoAdicField}; use p3_matrix::dense::RowMajorMatrix; use p3_maybe_rayon::prelude::*; @@ -54,8 +53,8 @@ fn decompose(poly: Vec, shift: F, log_chunks: usize) -> Vec< let half_n = poly.len() / 2; let g_inv = F::two_adic_generator(log_n).inverse(); - let mut even = Vec::with_capacity(half_n); - let mut odd = Vec::with_capacity(half_n); + let one_half = F::two().inverse(); + let (first, second) = poly.split_at(half_n); // Note that // p_e(g^(2i)) = (p(g^i) + p(g^(n/2 + i))) / 2 @@ -63,17 +62,28 @@ fn decompose(poly: Vec, shift: F, log_chunks: usize) -> Vec< // p_e(g^(2i)) = (a + b) / 2 // p_o(g^(2i)) = (a - b) / (2 s g^i) - let one_half = F::two().inverse(); - let (first, second) = poly.split_at(half_n); - for (g_inv_power, &a, &b) in izip!(g_inv.shifted_powers(shift.inverse()), first, second) { - let sum = a + b; - let diff = a - b; - even.push(sum * one_half); - odd.push(diff * one_half * g_inv_power); - } + let mut g_inv_powers = g_inv.shifted_powers(shift.inverse()); + let g_inv_powers = (0..first.len()) + .map(|_| g_inv_powers.next().unwrap()) + .collect::>(); + let (even, odd): (Vec<_>, Vec<_>) = first + .par_iter() + .zip(second.par_iter()) + .zip(g_inv_powers.par_iter()) + .map(|((&a, &b), g_inv_power)| { + let sum = a + b; + let diff = a - b; + (sum * one_half, diff * one_half * *g_inv_power) + }) + .unzip(); + + let (even_decomp, odd_decomp) = join( + || decompose(even, shift.square(), log_chunks - 1), + || decompose(odd, shift.square(), log_chunks - 1), + ); - let mut combined = decompose(even, shift.square(), log_chunks - 1); - combined.extend(decompose(odd, shift.square(), log_chunks - 1)); + let mut combined = even_decomp; + combined.extend(odd_decomp); combined } diff --git a/uni-stark/src/folder.rs b/uni-stark/src/folder.rs index c9b48791..8a5de17f 100644 --- a/uni-stark/src/folder.rs +++ b/uni-stark/src/folder.rs @@ -1,9 +1,9 @@ use p3_air::{AirBuilder, TwoRowMatrixView}; use p3_field::{AbstractField, Field}; -use crate::StarkConfig; +use crate::StarkGenericConfig; -pub struct ProverConstraintFolder<'a, SC: StarkConfig> { +pub struct ProverConstraintFolder<'a, SC: StarkGenericConfig> { pub main: TwoRowMatrixView<'a, SC::PackedVal>, pub is_first_row: SC::PackedVal, pub is_last_row: SC::PackedVal, @@ -21,7 +21,7 @@ pub struct VerifierConstraintFolder<'a, Challenge> { pub accumulator: Challenge, } -impl<'a, SC: StarkConfig> AirBuilder for ProverConstraintFolder<'a, SC> { +impl<'a, SC: StarkGenericConfig> AirBuilder for ProverConstraintFolder<'a, SC> { type F = SC::Val; type Expr = SC::PackedVal; type Var = SC::PackedVal; diff --git a/uni-stark/src/proof.rs b/uni-stark/src/proof.rs index 1bd3e576..b03a00b0 100644 --- a/uni-stark/src/proof.rs +++ b/uni-stark/src/proof.rs @@ -4,16 +4,16 
@@ use p3_commit::Pcs; use p3_matrix::dense::RowMajorMatrix; use serde::{Deserialize, Serialize}; -use crate::StarkConfig; +use crate::StarkGenericConfig; -type Val = ::Val; +type Val = ::Val; type ValMat = RowMajorMatrix>; -type Com = <::Pcs as Pcs, ValMat>>::Commitment; -type PcsProof = <::Pcs as Pcs, ValMat>>::Proof; +type Com = <::Pcs as Pcs, ValMat>>::Commitment; +type PcsProof = <::Pcs as Pcs, ValMat>>::Proof; #[derive(Serialize, Deserialize)] #[serde(bound = "")] -pub struct Proof { +pub struct Proof { pub(crate) commitments: Commitments>, pub(crate) opened_values: OpenedValues, pub(crate) opening_proof: PcsProof, diff --git a/uni-stark/src/prover.rs b/uni-stark/src/prover.rs index 6b079ba2..5c399224 100644 --- a/uni-stark/src/prover.rs +++ b/uni-stark/src/prover.rs @@ -17,8 +17,8 @@ use tracing::{info_span, instrument}; use crate::symbolic_builder::{get_log_quotient_degree, SymbolicAirBuilder}; use crate::{ - decompose_and_flatten, Commitments, OpenedValues, Proof, ProverConstraintFolder, StarkConfig, - ZerofierOnCoset, + decompose_and_flatten, Commitments, OpenedValues, Proof, ProverConstraintFolder, + StarkGenericConfig, ZerofierOnCoset, }; #[instrument(skip_all)] @@ -29,7 +29,7 @@ pub fn prove Air, ) -> Proof where - SC: StarkConfig, + SC: StarkGenericConfig, A: Air> + for<'a> Air>, { // #[cfg(debug_assertions)] @@ -120,7 +120,7 @@ fn quotient_values( alpha: SC::Challenge, ) -> Vec where - SC: StarkConfig, + SC: StarkGenericConfig, A: for<'a> Air>, Mat: MatrixGet + Sync, { diff --git a/uni-stark/src/symbolic_builder.rs b/uni-stark/src/symbolic_builder.rs index 471cbfc6..19ce5cec 100644 --- a/uni-stark/src/symbolic_builder.rs +++ b/uni-stark/src/symbolic_builder.rs @@ -6,10 +6,12 @@ use p3_air::{Air, AirBuilder}; use p3_field::Field; use p3_matrix::dense::RowMajorMatrix; use p3_util::log2_ceil_usize; +use tracing::instrument; use crate::symbolic_expression::SymbolicExpression; use crate::symbolic_variable::SymbolicVariable; +#[instrument(name = "infer log of constraint degree", skip_all)] pub fn get_log_quotient_degree(air: &A) -> usize where F: Field, @@ -24,6 +26,7 @@ where log2_ceil_usize(constraint_degree - 1) } +#[instrument(name = "infer constraint degree", skip_all, level = "debug")] pub fn get_max_constraint_degree(air: &A) -> usize where F: Field, @@ -36,6 +39,7 @@ where .unwrap_or(0) } +#[instrument(name = "evalute constraints symbolically", skip_all, level = "debug")] pub fn get_symbolic_constraints(air: &A) -> Vec> where F: Field, diff --git a/uni-stark/src/symbolic_expression.rs b/uni-stark/src/symbolic_expression.rs index e0ba03a8..03bbb614 100644 --- a/uni-stark/src/symbolic_expression.rs +++ b/uni-stark/src/symbolic_expression.rs @@ -15,10 +15,25 @@ pub enum SymbolicExpression { IsLastRow, IsTransition, Constant(F), - Add(Rc, Rc), - Sub(Rc, Rc), - Neg(Rc), - Mul(Rc, Rc), + Add { + x: Rc, + y: Rc, + degree_multiple: usize, + }, + Sub { + x: Rc, + y: Rc, + degree_multiple: usize, + }, + Neg { + x: Rc, + degree_multiple: usize, + }, + Mul { + x: Rc, + y: Rc, + degree_multiple: usize, + }, } impl SymbolicExpression { @@ -30,10 +45,18 @@ impl SymbolicExpression { SymbolicExpression::IsLastRow => 1, SymbolicExpression::IsTransition => 0, SymbolicExpression::Constant(_) => 0, - SymbolicExpression::Add(x, y) => x.degree_multiple().max(y.degree_multiple()), - SymbolicExpression::Sub(x, y) => x.degree_multiple().max(y.degree_multiple()), - SymbolicExpression::Neg(x) => x.degree_multiple(), - SymbolicExpression::Mul(x, y) => x.degree_multiple() + y.degree_multiple(), + 
SymbolicExpression::Add { + degree_multiple, .. + } => *degree_multiple, + SymbolicExpression::Sub { + degree_multiple, .. + } => *degree_multiple, + SymbolicExpression::Neg { + degree_multiple, .. + } => *degree_multiple, + SymbolicExpression::Mul { + degree_multiple, .. + } => *degree_multiple, } } } @@ -112,7 +135,12 @@ impl Add for SymbolicExpression { type Output = Self; fn add(self, rhs: Self) -> Self { - Self::Add(Rc::new(self), Rc::new(rhs)) + let degree_multiple = self.degree_multiple().max(rhs.degree_multiple()); + Self::Add { + x: Rc::new(self), + y: Rc::new(rhs), + degree_multiple, + } } } @@ -152,7 +180,12 @@ impl Sub for SymbolicExpression { type Output = Self; fn sub(self, rhs: Self) -> Self { - Self::Sub(Rc::new(self), Rc::new(rhs)) + let degree_multiple = self.degree_multiple().max(rhs.degree_multiple()); + Self::Sub { + x: Rc::new(self), + y: Rc::new(rhs), + degree_multiple, + } } } @@ -180,7 +213,11 @@ impl Neg for SymbolicExpression { type Output = Self; fn neg(self) -> Self { - Self::Neg(Rc::new(self)) + let degree_multiple = self.degree_multiple(); + Self::Neg { + x: Rc::new(self), + degree_multiple, + } } } @@ -188,7 +225,13 @@ impl Mul for SymbolicExpression { type Output = Self; fn mul(self, rhs: Self) -> Self { - Self::Mul(Rc::new(self), Rc::new(rhs)) + #[allow(clippy::suspicious_arithmetic_impl)] + let degree_multiple = self.degree_multiple() + rhs.degree_multiple(); + Self::Mul { + x: Rc::new(self), + y: Rc::new(rhs), + degree_multiple, + } } } @@ -232,10 +275,10 @@ impl Display for SymbolicExpression { SymbolicExpression::IsLastRow => write!(f, "IsLastRow"), SymbolicExpression::IsTransition => write!(f, "IsTransition"), SymbolicExpression::Constant(c) => write!(f, "{}", c), - SymbolicExpression::Add(x, y) => write!(f, "({} + {})", x, y), - SymbolicExpression::Sub(x, y) => write!(f, "({} - {})", x, y), - SymbolicExpression::Neg(x) => write!(f, "-{}", x), - SymbolicExpression::Mul(x, y) => write!(f, "({} * {})", x, y), + SymbolicExpression::Add { x, y, .. } => write!(f, "({} + {})", x, y), + SymbolicExpression::Sub { x, y, .. } => write!(f, "({} - {})", x, y), + SymbolicExpression::Neg { x, .. } => write!(f, "-{}", x), + SymbolicExpression::Mul { x, y, .. 
} => write!(f, "({} * {})", x, y), } } } diff --git a/uni-stark/src/verifier.rs b/uni-stark/src/verifier.rs index 4295c47e..96944aa2 100644 --- a/uni-stark/src/verifier.rs +++ b/uni-stark/src/verifier.rs @@ -7,10 +7,12 @@ use p3_commit::UnivariatePcs; use p3_field::{AbstractExtensionField, AbstractField, Field, TwoAdicField}; use p3_matrix::Dimensions; use p3_util::reverse_slice_index_bits; +use tracing::instrument; use crate::symbolic_builder::{get_log_quotient_degree, SymbolicAirBuilder}; -use crate::{Proof, StarkConfig, VerifierConstraintFolder}; +use crate::{Proof, StarkGenericConfig, VerifierConstraintFolder}; +#[instrument(skip_all)] pub fn verify( config: &SC, air: &A, @@ -18,7 +20,7 @@ pub fn verify( proof: &Proof, ) -> Result<(), VerificationError> where - SC: StarkConfig, + SC: StarkGenericConfig, A: Air> + for<'a> Air>, { let log_quotient_degree = get_log_quotient_degree::(air); diff --git a/uni-stark/tests/mul_air.rs b/uni-stark/tests/mul_air.rs index f50a9533..e14db7b0 100644 --- a/uni-stark/tests/mul_air.rs +++ b/uni-stark/tests/mul_air.rs @@ -6,14 +6,13 @@ use p3_commit::ExtensionMmcs; use p3_dft::Radix2DitParallel; use p3_field::extension::BinomialExtensionField; use p3_field::Field; -use p3_fri::{FriBasedPcs, FriConfigImpl, FriLdt}; -use p3_ldt::QuotientMmcs; +use p3_fri::{FriConfig, TwoAdicFriPcs, TwoAdicFriPcsConfig}; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::MatrixRowSlices; use p3_merkle_tree::FieldMerkleTreeMmcs; use p3_poseidon2::{DiffusionMatrixBabybear, Poseidon2}; use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; -use p3_uni_stark::{prove, verify, StarkConfigImpl, VerificationError}; +use p3_uni_stark::{prove, verify, StarkConfig, VerificationError}; use rand::distributions::{Distribution, Standard}; use rand::{thread_rng, Rng}; use tracing_forest::ForestLayer; @@ -96,16 +95,19 @@ fn test_prove_baby_bear() -> Result<(), VerificationError> { type Challenger = DuplexChallenger; - type Quotient = QuotientMmcs; - type MyFriConfig = FriConfigImpl; - let fri_config = MyFriConfig::new(1, 40, 8, challenge_mmcs); - let ldt = FriLdt { config: fri_config }; + let fri_config = FriConfig { + log_blowup: 1, + num_queries: 40, + proof_of_work_bits: 8, + mmcs: challenge_mmcs, + }; + type Pcs = + TwoAdicFriPcs>; + let pcs = Pcs::new(fri_config, dft, val_mmcs); - type Pcs = FriBasedPcs; - type MyConfig = StarkConfigImpl; + type MyConfig = StarkConfig; + let config = StarkConfig::new(pcs); - let pcs = Pcs::new(dft, val_mmcs, ldt); - let config = StarkConfigImpl::new(pcs); let mut challenger = Challenger::new(perm.clone()); let trace = random_valid_trace::(HEIGHT); let proof = prove::(&config, &MulAir, &mut challenger, trace); diff --git a/util/src/lib.rs b/util/src/lib.rs index 9e7ae2e1..d7c28cd2 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -7,6 +7,7 @@ extern crate alloc; use core::hint::unreachable_unchecked; pub mod array_serialization; +pub mod linear_map; /// Computes `ceil(a / b)`. Assumes `a + b` does not overflow. #[must_use] @@ -120,3 +121,22 @@ pub fn branch_hint() { core::arch::asm!("", options(nomem, nostack, preserves_flags)); } } + +/// Convenience methods for Vec. +pub trait VecExt { + /// Push `elem` and return a reference to it. + fn pushed_ref(&mut self, elem: T) -> &T; + /// Push `elem` and return a mutable reference to it. 
+    fn pushed_mut(&mut self, elem: T) -> &mut T;
+}
+
+impl<T> VecExt<T> for alloc::vec::Vec<T> {
+    fn pushed_ref(&mut self, elem: T) -> &T {
+        self.push(elem);
+        self.last().unwrap()
+    }
+    fn pushed_mut(&mut self, elem: T) -> &mut T {
+        self.push(elem);
+        self.last_mut().unwrap()
+    }
+}
diff --git a/util/src/linear_map.rs b/util/src/linear_map.rs
new file mode 100644
index 00000000..165a1289
--- /dev/null
+++ b/util/src/linear_map.rs
@@ -0,0 +1,65 @@
+use alloc::vec::Vec;
+use core::mem;
+
+use crate::VecExt;
+
+/// O(n) Vec-backed map for keys that only implement Eq.
+/// Only use this for a very small number of keys.
+pub struct LinearMap<K, V>(Vec<(K, V)>);
+
+impl<K, V> Default for LinearMap<K, V> {
+    fn default() -> Self {
+        Self(Default::default())
+    }
+}
+
+impl<K: Eq, V> LinearMap<K, V> {
+    pub fn new() -> Self {
+        Default::default()
+    }
+    pub fn get(&self, k: &K) -> Option<&V> {
+        self.0.iter().find(|(kk, _)| kk == k).map(|(_, v)| v)
+    }
+    pub fn get_mut(&mut self, k: &K) -> Option<&mut V> {
+        self.0.iter_mut().find(|(kk, _)| kk == k).map(|(_, v)| v)
+    }
+    pub fn insert(&mut self, k: K, mut v: V) -> Option<V> {
+        if let Some(vv) = self.get_mut(&k) {
+            mem::swap(&mut v, vv);
+            Some(v)
+        } else {
+            self.0.push((k, v));
+            None
+        }
+    }
+    pub fn get_or_insert_with(&mut self, k: K, f: impl FnOnce() -> V) -> &mut V {
+        let existing = self.0.iter().position(|(kk, _)| kk == &k);
+        if let Some(idx) = existing {
+            &mut self.0[idx].1
+        } else {
+            let slot = self.0.pushed_mut((k, f()));
+            &mut slot.1
+        }
+    }
+    pub fn values(&self) -> impl Iterator<Item = &V> {
+        self.0.iter().map(|(_, v)| v)
+    }
+}
+
+impl<K: Eq, V> FromIterator<(K, V)> for LinearMap<K, V> {
+    fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
+        let mut me = LinearMap::default();
+        for (k, v) in iter {
+            me.insert(k, v);
+        }
+        me
+    }
+}
+
+impl<K, V> IntoIterator for LinearMap<K, V> {
+    type Item = (K, V);
+    type IntoIter = <Vec<(K, V)> as IntoIterator>::IntoIter;
+    fn into_iter(self) -> Self::IntoIter {
+        self.0.into_iter()
+    }
+}
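A hedged usage sketch of the two utilities introduced above; the string keys and integer values are arbitrary:

```rust
use p3_util::linear_map::LinearMap;
use p3_util::VecExt;

fn main() {
    // LinearMap: O(n) lookups, intended for a handful of keys that are only Eq.
    let mut map = LinearMap::new();
    assert_eq!(map.insert("fri", 1), None);
    assert_eq!(map.insert("fri", 2), Some(1)); // returns the displaced value
    *map.get_or_insert_with("pcs", || 0) += 1;
    assert_eq!(map.get(&"pcs"), Some(&1));
    assert_eq!(map.values().sum::<i32>(), 3);

    // VecExt: push and immediately borrow the new element.
    let mut v: Vec<i32> = Vec::new();
    *v.pushed_mut(10) += 1;
    assert_eq!(v.pushed_ref(5), &5);
    assert_eq!(v, [11, 5]);
}
```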