diff --git a/air/src/virtual_column.rs b/air/src/virtual_column.rs index 0ccc9835..899207fa 100644 --- a/air/src/virtual_column.rs +++ b/air/src/virtual_column.rs @@ -1,3 +1,4 @@ +use alloc::borrow::Cow; use alloc::vec; use alloc::vec::Vec; use core::ops::Mul; @@ -6,8 +7,8 @@ use p3_field::{AbstractField, Field}; /// An affine function over columns in a PAIR. #[derive(Clone, Debug)] -pub struct VirtualPairCol { - column_weights: Vec<(PairCol, F)>, +pub struct VirtualPairCol<'a, F: Field> { + column_weights: Cow<'a, [(PairCol, F)]>, constant: F, } @@ -19,7 +20,7 @@ pub enum PairCol { } impl PairCol { - fn get(&self, preprocessed: &[T], main: &[T]) -> T { + pub const fn get(&self, preprocessed: &[T], main: &[T]) -> T { match self { PairCol::Preprocessed(i) => preprocessed[*i], PairCol::Main(i) => main[*i], @@ -27,14 +28,28 @@ impl PairCol { } } -impl VirtualPairCol { - pub fn new(column_weights: Vec<(PairCol, F)>, constant: F) -> Self { +impl<'a, F: Field> VirtualPairCol<'a, F> { + pub const fn new(column_weights: Cow<'a, [(PairCol, F)]>, constant: F) -> Self { Self { column_weights, constant, } } + pub const fn new_owned(column_weights: Vec<(PairCol, F)>, constant: F) -> Self { + Self { + column_weights: Cow::Owned(column_weights), + constant, + } + } + + pub const fn new_borrowed(column_weights: &'a [(PairCol, F)], constant: F) -> Self { + Self { + column_weights: Cow::Borrowed(column_weights), + constant, + } + } + pub fn new_preprocessed(column_weights: Vec<(usize, F)>, constant: F) -> Self { Self::new( column_weights @@ -55,6 +70,14 @@ impl VirtualPairCol { ) } + pub fn get_column_weights(&self) -> &[(PairCol, F)] { + &self.column_weights + } + + pub const fn get_constant(&self) -> F { + self.constant + } + #[must_use] pub fn one() -> Self { Self::constant(F::one()) @@ -63,7 +86,7 @@ impl VirtualPairCol { #[must_use] pub fn constant(x: F) -> Self { Self { - column_weights: vec![], + column_weights: Cow::Owned(vec![]), constant: x, } } @@ -71,7 +94,7 @@ impl VirtualPairCol { #[must_use] pub fn single(column: PairCol) -> Self { Self { - column_weights: vec![(column, F::one())], + column_weights: Cow::Owned(vec![(column, F::one())]), constant: F::zero(), } } @@ -117,7 +140,7 @@ impl VirtualPairCol { Var: Into + Copy, { let mut result = self.constant.into(); - for (column, weight) in &self.column_weights { + for (column, weight) in self.column_weights.iter() { result += column.get(preprocessed, main).into() * *weight; } result diff --git a/baby-bear/src/baby_bear.rs b/baby-bear/src/baby_bear.rs index 88f03180..30981483 100644 --- a/baby-bear/src/baby_bear.rs +++ b/baby-bear/src/baby_bear.rs @@ -3,8 +3,8 @@ use core::iter::{Product, Sum}; use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign}; use p3_field::{ - exp_1725656503, exp_u64_by_squaring, AbstractField, Field, Packable, PrimeField, PrimeField32, - PrimeField64, TwoAdicField, + exp_1725656503, exp_u64_by_squaring, halve_u32, AbstractField, Field, Packable, PrimeField, + PrimeField32, PrimeField64, TwoAdicField, }; use rand::distributions::{Distribution, Standard}; use rand::Rng; @@ -45,7 +45,7 @@ pub struct BabyBear { impl BabyBear { /// create a new `BabyBear` from a canonical `u32`. #[inline] - pub(crate) const fn new(n: u32) -> Self { + pub const fn new(n: u32) -> Self { Self { value: to_monty(n) } } } @@ -238,6 +238,13 @@ impl Field for BabyBear { Some(p1110111111111111111111111111111) } + + #[inline] + fn halve(&self) -> Self { + BabyBear { + value: halve_u32::

(self.value), + } + } } impl PrimeField for BabyBear {} diff --git a/commit/src/adapters/extension_mmcs.rs b/commit/src/adapters/extension_mmcs.rs index 9377586a..319eb431 100644 --- a/commit/src/adapters/extension_mmcs.rs +++ b/commit/src/adapters/extension_mmcs.rs @@ -14,7 +14,7 @@ pub struct ExtensionMmcs { } impl ExtensionMmcs { - pub fn new(inner: InnerMmcs) -> Self { + pub const fn new(inner: InnerMmcs) -> Self { Self { inner, _phantom: PhantomData, diff --git a/field-testing/src/lib.rs b/field-testing/src/lib.rs index 107c0db4..0b6959d4 100644 --- a/field-testing/src/lib.rs +++ b/field-testing/src/lib.rs @@ -26,6 +26,7 @@ where assert_eq!(x + (-x), F::zero()); assert_eq!(-x, F::zero() - x); assert_eq!(x + x, x * F::two()); + assert_eq!(x, x.halve() * F::two()); assert_eq!(x * (-x), -x.square()); assert_eq!(x + y, y + x); assert_eq!(x * y, y * x); diff --git a/field/src/extension/binomial_extension.rs b/field/src/extension/binomial_extension.rs index b773ad82..d44cfb30 100644 --- a/field/src/extension/binomial_extension.rs +++ b/field/src/extension/binomial_extension.rs @@ -218,6 +218,12 @@ impl, const D: usize> Field for BinomialExtensionFiel _ => Some(self.frobenius_inv()), } } + + fn halve(&self) -> Self { + Self { + value: self.value.map(|x| x.halve()), + } + } } impl Display for BinomialExtensionField diff --git a/field/src/field.rs b/field/src/field.rs index d547b55a..72a374da 100644 --- a/field/src/field.rs +++ b/field/src/field.rs @@ -207,6 +207,17 @@ pub trait Field: fn inverse(&self) -> Self { self.try_inverse().expect("Tried to invert zero") } + + /// Computes input/2. + /// Should be overwritten by most field implementations to use bitshifts. + /// Will error if the field characteristic is 2. + #[must_use] + fn halve(&self) -> Self { + let half = Self::two() + .try_inverse() + .expect("Cannot divide by 2 in fields with characteristic 2"); + *self * half + } } pub trait PrimeField: Field + Ord {} diff --git a/field/src/helpers.rs b/field/src/helpers.rs index 3167c506..1d0747f2 100644 --- a/field/src/helpers.rs +++ b/field/src/helpers.rs @@ -99,3 +99,31 @@ pub fn eval_poly(poly: &[AF], x: AF) -> AF { } acc } + +/// Given an element x from a 32 bit field F_P compute x/2. +#[inline] +pub fn halve_u32(input: u32) -> u32 { + let shift = (P + 1) >> 1; + let shr = input >> 1; + let lo_bit = input & 1; + let shr_corr = shr + shift; + if lo_bit == 0 { + shr + } else { + shr_corr + } +} + +/// Given an element x from a 64 bit field F_P compute x/2. +#[inline] +pub fn halve_u64(input: u64) -> u64 { + let shift = (P + 1) >> 1; + let shr = input >> 1; + let lo_bit = input & 1; + let shr_corr = shr + shift; + if lo_bit == 0 { + shr + } else { + shr_corr + } +} diff --git a/fri/src/two_adic_pcs.rs b/fri/src/two_adic_pcs.rs index 1845bef0..5eeea655 100644 --- a/fri/src/two_adic_pcs.rs +++ b/fri/src/two_adic_pcs.rs @@ -77,7 +77,7 @@ pub struct TwoAdicFriPcs { } impl TwoAdicFriPcs { - pub fn new(fri: FriConfig, dft: C::Dft, mmcs: C::InputMmcs) -> Self { + pub const fn new(fri: FriConfig, dft: C::Dft, mmcs: C::InputMmcs) -> Self { Self { fri, dft, mmcs } } } diff --git a/goldilocks/src/lib.rs b/goldilocks/src/lib.rs index e5f6557b..c53c01c7 100644 --- a/goldilocks/src/lib.rs +++ b/goldilocks/src/lib.rs @@ -11,14 +11,17 @@ use core::iter::{Product, Sum}; use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign}; use p3_field::{ - exp_10540996611094048183, exp_u64_by_squaring, AbstractField, Field, Packable, PrimeField, - PrimeField64, TwoAdicField, + exp_10540996611094048183, exp_u64_by_squaring, halve_u64, AbstractField, Field, Packable, + PrimeField, PrimeField64, TwoAdicField, }; use p3_util::{assume, branch_hint}; use rand::distributions::{Distribution, Standard}; use rand::Rng; use serde::{Deserialize, Serialize}; +/// The Goldilocks prime +const P: u64 = 0xFFFF_FFFF_0000_0001; + /// The prime field known as Goldilocks, defined as `F_p` where `p = 2^64 - 2^32 + 1`. #[derive(Copy, Clone, Default, Serialize, Deserialize)] pub struct Goldilocks { @@ -210,12 +213,17 @@ impl Field for Goldilocks { // compute base^1111111111111111111111111111111011111111111111111111111111111111 Some(t63.square() * *self) } + + #[inline] + fn halve(&self) -> Self { + Goldilocks::new(halve_u64::

(self.value)) + } } impl PrimeField for Goldilocks {} impl PrimeField64 for Goldilocks { - const ORDER_U64: u64 = 0xFFFF_FFFF_0000_0001; + const ORDER_U64: u64 = P; #[inline] fn as_canonical_u64(&self) -> u64 { diff --git a/keccak-air/Cargo.toml b/keccak-air/Cargo.toml index c1b5f97d..5d4d95b5 100644 --- a/keccak-air/Cargo.toml +++ b/keccak-air/Cargo.toml @@ -22,6 +22,7 @@ p3-keccak = { path = "../keccak" } p3-maybe-rayon = { path = "../maybe-rayon" } p3-mds = { path = "../mds" } p3-merkle-tree = { path = "../merkle-tree" } +p3-poseidon = {path = "../poseidon"} p3-poseidon2 = { path = "../poseidon2" } p3-symmetric = { path = "../symmetric" } p3-uni-stark = { path = "../uni-stark" } @@ -38,6 +39,9 @@ name = "prove_baby_bear_poseidon2" [[example]] name = "prove_goldilocks_keccak" +[[example]] +name = "prove_goldilocks_poseidon" + [features] # TODO: Consider removing, at least when this gets split off into another repository. # We should be able to enable p3-maybe-rayon/parallel directly; this just doesn't diff --git a/keccak-air/examples/prove_goldilocks_poseidon.rs b/keccak-air/examples/prove_goldilocks_poseidon.rs new file mode 100644 index 00000000..395eb01b --- /dev/null +++ b/keccak-air/examples/prove_goldilocks_poseidon.rs @@ -0,0 +1,83 @@ +use p3_challenger::DuplexChallenger; +use p3_commit::ExtensionMmcs; +use p3_dft::Radix2DitParallel; +use p3_field::extension::BinomialExtensionField; +use p3_field::Field; +use p3_fri::{FriConfig, TwoAdicFriPcs, TwoAdicFriPcsConfig}; +use p3_goldilocks::Goldilocks; +use p3_keccak_air::{generate_trace_rows, KeccakAir}; +use p3_mds::goldilocks::MdsMatrixGoldilocks; +use p3_merkle_tree::FieldMerkleTreeMmcs; +use p3_poseidon::Poseidon; +use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; +use p3_uni_stark::{prove, verify, StarkConfig, VerificationError}; +use rand::{random, thread_rng}; +use tracing_forest::util::LevelFilter; +use tracing_forest::ForestLayer; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::{EnvFilter, Registry}; + +const NUM_HASHES: usize = 680; + +fn main() -> Result<(), VerificationError> { + let env_filter = EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(); + + Registry::default() + .with(env_filter) + .with(ForestLayer::default()) + .init(); + + type Val = Goldilocks; + type Challenge = BinomialExtensionField; + + type Perm = Poseidon; + let perm = Perm::new_from_rng(4, 22, MdsMatrixGoldilocks, &mut thread_rng()); + + type MyHash = PaddingFreeSponge; + let hash = MyHash::new(perm.clone()); + + type MyCompress = TruncatedPermutation; + let compress = MyCompress::new(perm.clone()); + + type ValMmcs = FieldMerkleTreeMmcs< + ::Packing, + ::Packing, + MyHash, + MyCompress, + 4, + >; + let val_mmcs = ValMmcs::new(hash, compress); + + type ChallengeMmcs = ExtensionMmcs; + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + + type Dft = Radix2DitParallel; + let dft = Dft {}; + + type Challenger = DuplexChallenger; + + let fri_config = FriConfig { + log_blowup: 1, + num_queries: 100, + proof_of_work_bits: 16, + mmcs: challenge_mmcs, + }; + type Pcs = + TwoAdicFriPcs>; + let pcs = Pcs::new(fri_config, dft, val_mmcs); + + type MyConfig = StarkConfig; + let config = StarkConfig::new(pcs); + + let mut challenger = Challenger::new(perm.clone()); + + let inputs = (0..NUM_HASHES).map(|_| random()).collect::>(); + let trace = generate_trace_rows::(inputs); + let proof = prove::(&config, &KeccakAir {}, &mut challenger, trace); + + let mut challenger = Challenger::new(perm); + verify(&config, &KeccakAir {}, &mut challenger, &proof) +} diff --git a/merkle-tree/src/mmcs.rs b/merkle-tree/src/mmcs.rs index b3ad012a..7556ccc1 100644 --- a/merkle-tree/src/mmcs.rs +++ b/merkle-tree/src/mmcs.rs @@ -27,7 +27,7 @@ pub struct FieldMerkleTreeMmcs { } impl FieldMerkleTreeMmcs { - pub fn new(hash: H, compress: C) -> Self { + pub const fn new(hash: H, compress: C) -> Self { Self { hash, compress, diff --git a/mersenne-31/src/mersenne_31.rs b/mersenne-31/src/mersenne_31.rs index febed03a..e973f0e6 100644 --- a/mersenne-31/src/mersenne_31.rs +++ b/mersenne-31/src/mersenne_31.rs @@ -5,13 +5,16 @@ use core::iter::{Product, Sum}; use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign}; use p3_field::{ - exp_1717986917, exp_u64_by_squaring, AbstractField, Field, Packable, PrimeField, PrimeField32, - PrimeField64, + exp_1717986917, exp_u64_by_squaring, halve_u32, AbstractField, Field, Packable, PrimeField, + PrimeField32, PrimeField64, }; use rand::distributions::{Distribution, Standard}; use rand::Rng; use serde::{Deserialize, Serialize}; +/// The Mersenne31 prime +const P: u32 = (1 << 31) - 1; + /// The prime field `F_p` where `p = 2^31 - 1`. #[derive(Copy, Clone, Default, Serialize, Deserialize)] pub struct Mersenne31 { @@ -225,12 +228,17 @@ impl Field for Mersenne31 { p1111111111111111111111111111.exp_power_of_2(3) * p101; Some(p1111111111111111111111111111101) } + + #[inline] + fn halve(&self) -> Self { + Mersenne31::new(halve_u32::

(self.value)) + } } impl PrimeField for Mersenne31 {} impl PrimeField32 for Mersenne31 { - const ORDER_U32: u32 = (1 << 31) - 1; + const ORDER_U32: u32 = P; #[inline] fn as_canonical_u32(&self) -> u32 { diff --git a/symmetric/src/compression.rs b/symmetric/src/compression.rs index 1a3ad7a0..3347a7ca 100644 --- a/symmetric/src/compression.rs +++ b/symmetric/src/compression.rs @@ -58,7 +58,7 @@ where T: Clone, H: CryptographicHasher, { - pub fn new(hasher: H) -> Self { + pub const fn new(hasher: H) -> Self { Self { hasher, _phantom: PhantomData, diff --git a/symmetric/src/serializing_hasher.rs b/symmetric/src/serializing_hasher.rs index d2e88a22..61e87c57 100644 --- a/symmetric/src/serializing_hasher.rs +++ b/symmetric/src/serializing_hasher.rs @@ -15,13 +15,13 @@ pub struct SerializingHasher64 { } impl SerializingHasher32 { - pub fn new(inner: Inner) -> Self { + pub const fn new(inner: Inner) -> Self { Self { inner } } } impl SerializingHasher64 { - pub fn new(inner: Inner) -> Self { + pub const fn new(inner: Inner) -> Self { Self { inner } } } diff --git a/uni-stark/src/decompose.rs b/uni-stark/src/decompose.rs index 1858d614..c6a3a9a4 100644 --- a/uni-stark/src/decompose.rs +++ b/uni-stark/src/decompose.rs @@ -53,7 +53,6 @@ fn decompose(poly: Vec, shift: F, log_chunks: usize) -> Vec< let half_n = poly.len() / 2; let g_inv = F::two_adic_generator(log_n).inverse(); - let one_half = F::two().inverse(); let (first, second) = poly.split_at(half_n); // Note that @@ -73,7 +72,7 @@ fn decompose(poly: Vec, shift: F, log_chunks: usize) -> Vec< .map(|((&a, &b), g_inv_power)| { let sum = a + b; let diff = a - b; - (sum * one_half, diff * one_half * *g_inv_power) + (sum.halve(), diff.halve() * *g_inv_power) }) .unzip();