diff --git a/.clippy.toml b/.clippy.toml index 9ed6287e0..800113fde 100644 --- a/.clippy.toml +++ b/.clippy.toml @@ -8,4 +8,4 @@ disallowed-methods = [ { path = "std::vec::Vec::leak", reason = "Not running the destructors on futures created inside seq_join module will cause UB in IPA. Make sure you don't leak any of those." }, ] -future-size-threshold = 8192 \ No newline at end of file +future-size-threshold = 10240 diff --git a/ipa-core/src/helpers/hashing.rs b/ipa-core/src/helpers/hashing.rs index 10f50484b..ae9579097 100644 --- a/ipa-core/src/helpers/hashing.rs +++ b/ipa-core/src/helpers/hashing.rs @@ -12,7 +12,7 @@ use crate::{ protocol::prss::FromRandomU128, }; -#[derive(Debug, PartialEq)] +#[derive(Clone, Debug, Default, PartialEq)] pub struct Hash(Output); impl Serializable for Hash { diff --git a/ipa-core/src/helpers/mod.rs b/ipa-core/src/helpers/mod.rs index 2c43ccd53..1c49e7de2 100644 --- a/ipa-core/src/helpers/mod.rs +++ b/ipa-core/src/helpers/mod.rs @@ -7,6 +7,7 @@ use std::{ convert::Infallible, fmt::{Debug, Display, Formatter}, num::NonZeroUsize, + ops::Not, }; use generic_array::GenericArray; @@ -267,6 +268,17 @@ pub enum Direction { Right, } +impl Not for Direction { + type Output = Self; + + fn not(self) -> Self { + match self { + Direction::Left => Direction::Right, + Direction::Right => Direction::Left, + } + } +} + impl Role { const H1_STR: &'static str = "H1"; const H2_STR: &'static str = "H2"; diff --git a/ipa-core/src/protocol/context/dzkp_malicious.rs b/ipa-core/src/protocol/context/dzkp_malicious.rs index 671dfa08d..951b5b4c4 100644 --- a/ipa-core/src/protocol/context/dzkp_malicious.rs +++ b/ipa-core/src/protocol/context/dzkp_malicious.rs @@ -13,7 +13,6 @@ use crate::{ context::{ dzkp_validator::{Batch, MaliciousDZKPValidatorInner, Segment}, prss::InstrumentedIndexedSharedRandomness, - step::DzkpBatchStep, Context as ContextTrait, DZKPContext, InstrumentedSequentialSharedRandomness, MaliciousContext, }, @@ -100,9 +99,7 @@ impl<'a, B: ShardBinding> DZKPContext for DZKPUpgraded<'a, B> { .batcher .lock() .unwrap() - .validate_record(record_id, |batch_idx, batch| { - batch.validate(ctx.narrow(&DzkpBatchStep(batch_idx))) - }); + .validate_record(record_id, |batch_idx, batch| batch.validate(ctx, batch_idx)); validation_future.await } diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index d21f1d1c3..3ed370cf5 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -15,11 +15,14 @@ use crate::{ dzkp_field::{DZKPBaseField, UVTupleBlock}, dzkp_malicious::DZKPUpgraded as MaliciousDZKPUpgraded, dzkp_semi_honest::DZKPUpgraded as SemiHonestDZKPUpgraded, - step::{DzkpSingleBatchStep, DzkpValidationProtocolStep as Step}, + step::DzkpValidationProtocolStep as Step, Base, Context, DZKPContext, MaliciousContext, MaliciousProtocolSteps, }, - ipa_prf::validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify}, - Gate, RecordId, + ipa_prf::{ + validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify}, + LargeProofGenerator, SmallProofGenerator, + }, + Gate, RecordId, RecordIdRange, }, seq_join::{seq_join, SeqJoin}, sharding::ShardBinding, @@ -52,6 +55,26 @@ pub const TARGET_PROOF_SIZE: usize = 8192; #[cfg(not(test))] pub const TARGET_PROOF_SIZE: usize = 50_000_000; +/// Maximum proof recursion depth. +// +// This is a hard limit. Each GF(2) multiply generates four G values and four H values, +// and the last level of the proof is limited to (small_recursion_factor - 1), so the +// restriction is: +// +// $$ large_recursion_factor * (small_recursion_factor - 1) +// * small_recursion_factor ^ (depth - 2) >= 4 * target_proof_size $$ +// +// With large_recursion_factor = 32 and small_recursion_factor = 8, this means: +// +// $$ depth >= log_8 (8/7 * target_proof_size) $$ +// +// Because the number of records in a proof batch is often rounded up to a power of two +// (and less significantly, because multiplication intermediate storage gets rounded up +// to blocks of 256), leaving some margin is advised. +// +// The implementation requires that MAX_PROOF_RECURSION is at least 2. +pub const MAX_PROOF_RECURSION: usize = 9; + /// `MultiplicationInputsBlock` is a block of fixed size of intermediate values /// that occur duringa multiplication. /// These values need to be verified since there might have been malicious behavior. @@ -564,9 +587,22 @@ impl Batch { /// ## Panics /// If `usize` to `u128` conversion fails. - pub(super) async fn validate(self, ctx: Base<'_, B>) -> Result<(), Error> { + pub(super) async fn validate( + self, + ctx: Base<'_, B>, + batch_index: usize, + ) -> Result<(), Error> { + const PRSS_RECORDS_PER_BATCH: usize = LargeProofGenerator::PROOF_LENGTH + + (MAX_PROOF_RECURSION - 1) * SmallProofGenerator::PROOF_LENGTH + + 2; // P and Q masks + let proof_ctx = ctx.narrow(&Step::GenerateProof); + let record_id = RecordId::from(batch_index); + let prss_record_id_start = RecordId::from(batch_index * PRSS_RECORDS_PER_BATCH); + let prss_record_id_end = RecordId::from((batch_index + 1) * PRSS_RECORDS_PER_BATCH); + let prss_record_ids = RecordIdRange::from(prss_record_id_start..prss_record_id_end); + if self.is_empty() { return Ok(()); } @@ -578,11 +614,12 @@ impl Batch { q_mask_from_left_prover, ) = { // generate BatchToVerify - ProofBatch::generate(&proof_ctx, self.get_field_values_prover()) + ProofBatch::generate(&proof_ctx, prss_record_ids, self.get_field_values_prover()) }; let chunk_batch = BatchToVerify::generate_batch_to_verify( proof_ctx, + record_id, my_batch_left_shares, shares_of_batch_from_left_prover, p_mask_from_right_prover, @@ -592,7 +629,7 @@ impl Batch { // generate challenges let (challenges_for_left_prover, challenges_for_right_prover) = chunk_batch - .generate_challenges(ctx.narrow(&Step::Challenge)) + .generate_challenges(ctx.narrow(&Step::Challenge), record_id) .await; let (sum_of_uv, p_r_right_prover, q_r_left_prover) = { @@ -626,6 +663,7 @@ impl Batch { chunk_batch .verify( ctx.narrow(&Step::VerifyProof), + record_id, sum_of_uv, p_r_right_prover, q_r_left_prover, @@ -661,7 +699,18 @@ pub trait DZKPValidator: Send + Sync { /// /// # Panics /// May panic if the above restrictions on validator usage are not followed. - async fn validate(self) -> Result<(), Error>; + async fn validate(self) -> Result<(), Error> + where + Self: Sized, + { + self.validate_indexed(0).await + } + + /// Validates all of the multiplies associated with this validator, specifying + /// an explicit batch index. + /// + /// This should be used when the protocol is explicitly managing batches. + async fn validate_indexed(self, batch_index: usize) -> Result<(), Error>; /// `is_verified` checks that there are no `MultiplicationInputs` that have not been verified /// within the associated `DZKPBatch` @@ -707,6 +756,20 @@ pub trait DZKPValidator: Send + Sync { } } +// Wrapper to avoid https://github.com/rust-lang/rust/issues/100013. +pub fn validated_seq_join<'st, V, S, F, O>( + validator: V, + source: S, +) -> impl Stream> + Send + 'st +where + V: DZKPValidator + 'st, + S: Stream + Send + 'st, + F: Future> + Send + 'st, + O: Send + Sync + 'static, +{ + validator.validated_seq_join(source) +} + #[derive(Clone)] pub struct SemiHonestDZKPValidator<'a, B: ShardBinding> { context: SemiHonestDZKPUpgraded<'a, B>, @@ -732,7 +795,7 @@ impl<'a, B: ShardBinding> DZKPValidator for SemiHonestDZKPValidator<'a, B> { // Semi-honest validator doesn't do anything, so doesn't care. } - async fn validate(self) -> Result<(), Error> { + async fn validate_indexed(self, _batch_index: usize) -> Result<(), Error> { Ok(()) } @@ -778,7 +841,7 @@ impl<'a, B: ShardBinding> DZKPValidator for MaliciousDZKPValidator<'a, B> { .set_total_records(total_records); } - async fn validate(mut self) -> Result<(), Error> { + async fn validate_indexed(mut self, batch_index: usize) -> Result<(), Error> { let arc = self .inner_ref .take() @@ -793,7 +856,7 @@ impl<'a, B: ShardBinding> DZKPValidator for MaliciousDZKPValidator<'a, B> { batcher .into_single_batch() - .validate(validate_ctx.narrow(&DzkpSingleBatchStep)) + .validate(validate_ctx, batch_index) .await } @@ -1262,16 +1325,12 @@ mod tests { } proptest! { - #![proptest_config(ProptestConfig::with_cases(50))] + #![proptest_config(ProptestConfig::with_cases(20))] #[test] fn batching_proptest((record_count, max_multiplications_per_gate) in batching()) { println!("record_count {record_count} batch {max_multiplications_per_gate}"); - if record_count / max_multiplications_per_gate >= 192 { - // TODO: #1269, or even if we don't fix that, don't hardcode the limit. - println!("skipping config because batch count exceeds limit of 192"); - } // This condition is correct only for active_work = 16 and record size of 1 byte. - else if max_multiplications_per_gate != 1 && max_multiplications_per_gate % 16 != 0 { + if max_multiplications_per_gate != 1 && max_multiplications_per_gate % 16 != 0 { // TODO: #1300, read_size | batch_size. // Note: for active work < 2048, read size matches active work. diff --git a/ipa-core/src/protocol/context/step.rs b/ipa-core/src/protocol/context/step.rs index aeb6bd76f..24a8872be 100644 --- a/ipa-core/src/protocol/context/step.rs +++ b/ipa-core/src/protocol/context/step.rs @@ -28,18 +28,6 @@ pub(crate) enum ValidateStep { CheckZero, } -// This really is only for DZKPs and not for MACs. The MAC protocol uses record IDs to -// count batches. DZKP probably should do the same to avoid the fixed upper limit. -#[derive(CompactStep)] -#[step(count = 600, child = DzkpValidationProtocolStep)] -pub(crate) struct DzkpBatchStep(pub usize); - -// This is used when we don't do batched verification, to avoid paying for x256 as many -// steps in compact gate. -#[derive(CompactStep)] -#[step(child = DzkpValidationProtocolStep)] -pub(crate) struct DzkpSingleBatchStep; - #[derive(CompactStep)] pub(crate) enum DzkpValidationProtocolStep { /// Step for proof generation diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs index 65ec97230..be08c6593 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs @@ -105,7 +105,7 @@ where let validator = ctx.clone().dzkp_validator( MaliciousProtocolSteps { protocol: &Step::aggregate(depth), - validate: &Step::aggregate_validate(chunk_counter), + validate: &Step::AggregateValidate, }, // We have to specify usize::MAX here because the procession through // record IDs is different at each step of the reduction. The batch @@ -119,7 +119,7 @@ where Some(&mut record_ids), ) .await?; - validator.validate().await?; + validator.validate_indexed(chunk_counter).await?; chunk_counter += 1; next_intermediate_results.push(result); } diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/step.rs b/ipa-core/src/protocol/ipa_prf/aggregation/step.rs index 0995a8e54..8be4fdcd1 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/step.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/step.rs @@ -10,17 +10,17 @@ pub(crate) enum AggregationStep { #[step(child = crate::protocol::ipa_prf::shuffle::step::OPRFShuffleStep)] Shuffle, Reveal, - #[step(child = crate::protocol::context::step::DzkpSingleBatchStep)] + #[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)] RevealValidate, // only partly used -- see code - #[step(count = 4, child = AggregateChunkStep)] + #[step(count = 4, child = AggregateChunkStep, name = "chunks")] Aggregate(usize), - #[step(count = 600, child = crate::protocol::context::step::DzkpSingleBatchStep)] - AggregateValidate(usize), + #[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)] + AggregateValidate, } // The step count here is duplicated as the AGGREGATE_DEPTH constant in the code. #[derive(CompactStep)] -#[step(count = 24, child = AggregateValuesStep, name = "depth")] +#[step(count = 24, child = AggregateValuesStep, name = "fold")] pub(crate) struct AggregateChunkStep(usize); #[derive(CompactStep)] diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index 2477f0867..32dbc3a18 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -50,6 +50,16 @@ where } } +impl Default for CanonicalLagrangeDenominator +where + F: PrimeField + TryFrom, + >::Error: Debug, +{ + fn default() -> Self { + Self::new() + } +} + /// `LagrangeTable` is a precomputed table for the Lagrange evaluation. /// Allows to compute points on the polynomial, i.e. output points, /// given enough points on the polynomial, i.e. input points, diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs index 780eabcf2..808dc4476 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs @@ -3,14 +3,14 @@ use std::{borrow::Borrow, iter::zip, marker::PhantomData}; #[cfg(all(test, unit_test))] use crate::ff::Fp31; use crate::{ - error::{Error, Error::DZKPMasks}, + error::Error::{self, DZKPMasks}, ff::{Fp61BitPrime, PrimeField}, helpers::hashing::{compute_hash, hash_to_field}, protocol::{ context::Context, ipa_prf::malicious_security::lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, prss::SharedRandomness, - RecordId, + RecordId, RecordIdRange, }, }; @@ -179,7 +179,7 @@ impl ProofGenerat .collect::>() } - fn gen_proof_shares_from_prss(ctx: &C, record_counter: &mut RecordId) -> ([F; P], [F; P]) + fn gen_proof_shares_from_prss(ctx: &C, record_ids: &mut RecordIdRange) -> ([F; P], [F; P]) where C: Context, { @@ -187,9 +187,9 @@ impl ProofGenerat let mut out_right = [F::ZERO; P]; // use PRSS for i in 0..P { - let (left, right) = ctx.prss().generate_fields::(*record_counter); - - *record_counter += 1; + let (left, right) = ctx + .prss() + .generate_fields::(record_ids.expect_next()); out_left[i] = left; out_right[i] = right; @@ -215,7 +215,7 @@ impl ProofGenerat /// `my_proof_left_share` has type `Vec<[F; P]>`, pub fn gen_artefacts_from_recursive_step( ctx: &C, - record_counter: &mut RecordId, + record_ids: &mut RecordIdRange, lagrange_table: &LagrangeTable, uv_iterator: J, ) -> (UVValues, [F; P], [F; P]) @@ -230,7 +230,7 @@ impl ProofGenerat // generate proof shares from prss let (share_of_proof_from_prover_left, my_proof_right_share) = - Self::gen_proof_shares_from_prss(ctx, record_counter); + Self::gen_proof_shares_from_prss(ctx, record_ids); // generate prover left proof let my_proof_left_share = Self::gen_other_proof_share(my_proof, my_proof_right_share); @@ -267,7 +267,7 @@ mod test { lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, prover::{LargeProofGenerator, SmallProofGenerator, TestProofGenerator, UVValues}, }, - RecordId, + RecordId, RecordIdRange, }, seq_join::SeqJoin, test_executor::run, @@ -396,11 +396,11 @@ mod test { // first iteration let world = TestWorld::default(); - let mut record_counter = RecordId::from(0); + let mut record_ids = RecordIdRange::ALL; let (uv_values, _, _) = TestProofGenerator::gen_artefacts_from_recursive_step::<_, _, _, 4>( &world.contexts()[0], - &mut record_counter, + &mut record_ids, &lagrange_table, uv_1.iter(), ); @@ -496,11 +496,11 @@ mod test { let world = TestWorld::default(); let [helper_1_proofs, helper_2_proofs, helper_3_proofs] = world .semi_honest((), |ctx, ()| async move { - let mut record_counter = RecordId::from(0); + let mut record_ids = RecordIdRange::ALL; (0..NUM_PROOFS) .map(|i| { - assert_eq!(i * 7, usize::from(record_counter)); - TestProofGenerator::gen_proof_shares_from_prss(&ctx, &mut record_counter) + assert_eq!(i * 7, usize::from(record_ids.peek_first())); + TestProofGenerator::gen_proof_shares_from_prss(&ctx, &mut record_ids) }) .collect::>() }) @@ -550,9 +550,9 @@ mod test { let [(h1_proof_left, h1_proof_right), (h2_proof_left, h2_proof_right), (h3_proof_left, h3_proof_right)] = world .semi_honest((), |ctx, ()| async move { - let mut record_counter = RecordId::from(0); + let mut record_ids = RecordIdRange::ALL; let (proof_share_left, my_share_of_right) = - TestProofGenerator::gen_proof_shares_from_prss(&ctx, &mut record_counter); + TestProofGenerator::gen_proof_shares_from_prss(&ctx, &mut record_ids); let proof_u128 = match ctx.role() { Role::H1 => PROOF_1, Role::H2 => PROOF_2, diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index 9626731e4..559705df5 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -55,6 +55,8 @@ pub(crate) mod shuffle; pub(crate) mod step; pub mod validation_protocol; +pub use malicious_security::prover::{LargeProofGenerator, SmallProofGenerator}; + /// Match key type pub type MatchKey = BA64; /// Match key size @@ -755,7 +757,7 @@ mod compact_gate_tests { fn step_count_limit() { // This is an arbitrary limit intended to catch changes that unintentionally // blow up the step count. It can be increased, within reason. - const STEP_COUNT_LIMIT: u32 = 35_000; + const STEP_COUNT_LIMIT: u32 = 24_000; assert!( ProtocolStep::STEP_COUNT < STEP_COUNT_LIMIT, "Step count of {actual} exceeds limit of {STEP_COUNT_LIMIT}.", diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs index 710b0a7e3..d3f123bf3 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs @@ -8,7 +8,7 @@ pub struct UserNthRowStep(usize); pub(crate) enum AttributionStep { #[step(child = UserNthRowStep)] Attribute, - #[step(child = crate::protocol::context::step::DzkpBatchStep)] + #[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)] AttributeValidate, #[step(child = crate::protocol::ipa_prf::aggregation::step::AggregationStep)] Aggregate, diff --git a/ipa-core/src/protocol/ipa_prf/quicksort.rs b/ipa-core/src/protocol/ipa_prf/quicksort.rs index 30218e5c2..943dfb1ec 100644 --- a/ipa-core/src/protocol/ipa_prf/quicksort.rs +++ b/ipa-core/src/protocol/ipa_prf/quicksort.rs @@ -15,8 +15,8 @@ use crate::{ basics::reveal, boolean::{step::ThirtyTwoBitStep, NBitStep}, context::{ - dzkp_validator::DZKPValidator, Context, DZKPUpgraded, MaliciousProtocolSteps, - UpgradableContext, + dzkp_validator::{validated_seq_join, DZKPValidator, TARGET_PROOF_SIZE}, + Context, DZKPUpgraded, MaliciousProtocolSteps, UpgradableContext, }, ipa_prf::{ boolean_ops::comparison_and_subtraction_sequential::compare_gt, @@ -97,6 +97,10 @@ where } } +fn quicksort_proof_chunk(key_bits: usize) -> usize { + (TARGET_PROOF_SIZE / key_bits / SORT_CHUNK).next_power_of_two() +} + /// Insecure quicksort using MPC comparisons and a key extraction function `get_key`. /// /// `get_key` takes as input an element in the slice and outputs the key by which we sort by @@ -174,9 +178,7 @@ where protocol: &Step::quicksort_pass(quicksort_pass), validate: &Step::quicksort_pass_validate(quicksort_pass), }, - // TODO: use something like this when validating in chunks - // `TARGET_PROOF_SIZE / usize::try_from(K::BITS).unwrap() / SORT_CHUNK`` - total_records_usize.next_power_of_two(), + quicksort_proof_chunk(usize::try_from(K::BITS).unwrap()), ); let c = v.context(); let cmp_ctx = c.narrow(&QuicksortPassStep::Compare); @@ -186,7 +188,7 @@ where stream::iter(ranges_to_sort.clone().into_iter().filter(|r| r.len() > 1)) .flat_map(|range| { // set up iterator - let mut iterator = list[range.clone()].iter().map(get_key).cloned(); + let mut iterator = list[range].iter().map(get_key).cloned(); // first element is pivot, apply key extraction function f let pivot = iterator.next().unwrap(); repeat(pivot).zip(stream::iter(iterator)) @@ -197,8 +199,8 @@ where K::BITS <= ThirtyTwoBitStep::BITS, "ThirtyTwoBitStep is not large enough to accommodate this sort" ); - let compare_results = seq_join( - ctx.active_work(), + let compare_results = validated_seq_join( + v, process_stream_by_chunks::<_, _, _, _, _, _, SORT_CHUNK>( compare_index_pairs, (Vec::new(), Vec::new()), @@ -218,9 +220,6 @@ where .try_collect::>() .await?; - // TODO: validate in chunks rather than for the entire input - v.validate().await?; - let revealed: BitVec = seq_join( ctx.active_work(), stream::iter(compare_results).enumerate().map(|(i, chunk)| { @@ -275,7 +274,7 @@ where #[cfg(all(test, unit_test))] pub mod tests { use std::{ - cmp::Ordering, + cmp::{min, Ordering}, iter::{repeat, repeat_with}, }; @@ -392,6 +391,57 @@ pub mod tests { }); } + #[test] + fn test_quicksort_insecure_malicious_batching() { + run(|| async move { + const COUNT: usize = 600; + let world = TestWorld::default(); + let mut rng = thread_rng(); + + // generate vector of random values + let records: Vec = repeat_with(|| rng.gen()).take(COUNT).collect(); + + // Smaller ranges means fewer passes, makes the test faster. + // (With no impact on proof size, because there is a proof per pass.) + let ranges = (0..COUNT) + .step_by(8) + .map(|i| i..min(i + 8, COUNT)) + .collect::>(); + + // convert expected into more readable format + let mut expected: Vec = + records.clone().into_iter().map(|x| x.as_u128()).collect(); + // sort expected + for range in ranges.iter().cloned() { + expected[range].sort_unstable(); + } + + // compute mpc sort + let result: Vec<_> = world + .malicious(records.into_iter(), |ctx, mut r| { + let ranges_copy = ranges.clone(); + async move { + #[allow(clippy::single_range_in_vec_init)] + quicksort_ranges_by_key_insecure(ctx, &mut r, false, |x| x, ranges_copy) + .await + .unwrap(); + r + } + }) + .await + .reconstruct(); + + assert_eq!( + // convert into more readable format + result + .into_iter() + .map(|x| x.as_u128()) + .collect::>(), + expected + ); + }); + } + #[test] fn test_quicksort_insecure_semi_honest_trivial() { run(|| async move { diff --git a/ipa-core/src/protocol/ipa_prf/step.rs b/ipa-core/src/protocol/ipa_prf/step.rs index 633f157cc..5020a7185 100644 --- a/ipa-core/src/protocol/ipa_prf/step.rs +++ b/ipa-core/src/protocol/ipa_prf/step.rs @@ -8,7 +8,7 @@ pub(crate) enum IpaPrfStep { Shuffle, #[step(child = crate::protocol::ipa_prf::boolean_ops::step::Fp25519ConversionStep)] ConvertFp25519, - #[step(child = crate::protocol::context::step::DzkpBatchStep)] + #[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)] ConvertFp25519Validate, PrfKeyGen, #[step(child = crate::protocol::context::step::MaliciousProtocolStep)] @@ -19,7 +19,7 @@ pub(crate) enum IpaPrfStep { Attribution, #[step(child = crate::protocol::dp::step::DPStep, name = "dp")] DifferentialPrivacy, - #[step(child = crate::protocol::context::step::DzkpSingleBatchStep)] + #[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)] DifferentialPrivacyValidate, } @@ -28,7 +28,7 @@ pub(crate) enum QuicksortStep { /// Sort up to 1B rows. We can't exceed that limit for other reasons as well `record_id`. #[step(count = 30, child = crate::protocol::ipa_prf::step::QuicksortPassStep)] QuicksortPass(usize), - #[step(count = 30, child = crate::protocol::context::step::DzkpSingleBatchStep)] + #[step(count = 30, child = crate::protocol::context::step::DzkpValidationProtocolStep)] QuicksortPassValidate(usize), } diff --git a/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs b/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs index 5eccdc084..cb2754e5f 100644 --- a/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs +++ b/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs @@ -1,10 +1,16 @@ +use std::{array, iter::zip}; + +use typenum::{UInt, UTerm, Unsigned, B0, B1}; + use crate::{ + const_assert_eq, error::Error, - ff::Fp61BitPrime, - helpers::{Direction, TotalRecords}, + ff::{Fp61BitPrime, Serializable}, + helpers::{Direction, MpcMessage, TotalRecords}, protocol::{ context::{ dzkp_field::{UVTupleBlock, BLOCK_SIZE}, + dzkp_validator::MAX_PROOF_RECURSION, Context, }, ipa_prf::malicious_security::{ @@ -12,8 +18,9 @@ use crate::{ prover::{LargeProofGenerator, SmallProofGenerator}, }, prss::SharedRandomness, - RecordId, + RecordId, RecordIdRange, }, + secret_sharing::SharedValue, }; /// This a `ProofBatch` generated by a prover. @@ -47,11 +54,18 @@ impl ProofBatch { self.proofs.len() * SmallProofGenerator::PROOF_LENGTH + LargeProofGenerator::PROOF_LENGTH } - /// This function returns an iterator over the field elements of all proofs. - fn iter(&self) -> impl Iterator { - self.first_proof + #[allow(clippy::unnecessary_box_returns)] // clippy bug? `Array` exceeds unnecessary-box-size + fn to_array(&self) -> Box { + assert!(self.len() <= ARRAY_LEN); + let iter = self + .first_proof .iter() - .chain(self.proofs.iter().flat_map(|x| x.iter())) + .chain(self.proofs.iter().flat_map(|x| x.iter())); + let mut array = Box::new(array::from_fn(|_| Fp61BitPrime::ZERO)); + for (i, v) in iter.enumerate() { + array[i] = *v; + } + array } /// Each helper party generates a set of proofs, which are secret-shared. @@ -66,7 +80,11 @@ impl ProofBatch { /// ## Panics /// Panics when the function fails to set the masks without overwritting `u` and `v` values. /// This only happens when there is an issue in the recursion. - pub fn generate(ctx: &C, uv_tuple_inputs: I) -> (Self, Self, Fp61BitPrime, Fp61BitPrime) + pub fn generate( + ctx: &C, + mut prss_record_ids: RecordIdRange, + uv_tuple_inputs: I, + ) -> (Self, Self, Fp61BitPrime, Fp61BitPrime) where C: Context, I: Iterator> + Clone, @@ -77,9 +95,6 @@ impl ProofBatch { const SLL: usize = SmallProofGenerator::LAGRANGE_LENGTH; const SPL: usize = SmallProofGenerator::PROOF_LENGTH; - // set up record counter - let mut record_counter = RecordId::FIRST; - // precomputation for first proof let first_denominator = CanonicalLagrangeDenominator::::new(); let first_lagrange_table = LagrangeTable::::from(first_denominator); @@ -88,32 +103,40 @@ impl ProofBatch { let (mut uv_values, first_proof_from_left, my_first_proof_left_share) = LargeProofGenerator::gen_artefacts_from_recursive_step( ctx, - &mut record_counter, + &mut prss_record_ids, &first_lagrange_table, ProofBatch::polynomials_from_inputs(uv_tuple_inputs), ); - // approximate length of proof vector (rounded up) - let uv_len_bits: u32 = usize::BITS - uv_values.len().leading_zeros(); - let small_recursion_factor_bits: u32 = usize::BITS - SRF.leading_zeros(); - let expected_len = 1 << (uv_len_bits - small_recursion_factor_bits); + // `MAX_PROOF_RECURSION - 2` because: + // * The first level of recursion has already happened. + // * We need (SRF - 1) at the last level to have room for the masks. + let max_uv_values: usize = + (SRF - 1) * SRF.pow(u32::try_from(MAX_PROOF_RECURSION - 2).unwrap()); + assert!( + uv_values.len() <= max_uv_values, + "Proof batch is too large: have {} uv_values, max is {}", + uv_values.len(), + max_uv_values, + ); // storage for other proofs - let mut my_proofs_left_shares = Vec::<[Fp61BitPrime; SPL]>::with_capacity(expected_len); + let mut my_proofs_left_shares = + Vec::<[Fp61BitPrime; SPL]>::with_capacity(MAX_PROOF_RECURSION - 1); let mut shares_of_proofs_from_prover_left = - Vec::<[Fp61BitPrime; SPL]>::with_capacity(expected_len); + Vec::<[Fp61BitPrime; SPL]>::with_capacity(MAX_PROOF_RECURSION - 1); // generate masks // Prover `P_i` and verifier `P_{i-1}` both compute p(x) // therefore the "right" share computed by this verifier corresponds to that which // was used by the prover to the right. - let (my_p_mask, p_mask_from_right_prover) = ctx.prss().generate_fields(record_counter); - record_counter += 1; + let (my_p_mask, p_mask_from_right_prover) = + ctx.prss().generate_fields(prss_record_ids.expect_next()); // Prover `P_i` and verifier `P_{i+1}` both compute q(x) // therefore the "left" share computed by this verifier corresponds to that which // was used by the prover to the left. - let (q_mask_from_left_prover, my_q_mask) = ctx.prss().generate_fields(record_counter); - record_counter += 1; + let (q_mask_from_left_prover, my_q_mask) = + ctx.prss().generate_fields(prss_record_ids.expect_next()); let denominator = CanonicalLagrangeDenominator::::new(); let lagrange_table = LagrangeTable::::from(denominator); @@ -135,7 +158,7 @@ impl ProofBatch { let (uv_values_new, share_of_proof_from_prover_left, my_proof_left_share) = SmallProofGenerator::gen_artefacts_from_recursive_step( ctx, - &mut record_counter, + &mut prss_record_ids, &lagrange_table, uv_values.iter(), ); @@ -165,52 +188,40 @@ impl ProofBatch { /// /// ## Errors /// Propagates error from sending values over the network channel. - pub async fn send_to_left(&self, ctx: &C) -> Result<(), Error> + pub async fn send_to_left(&self, ctx: &C, record_id: RecordId) -> Result<(), Error> where C: Context, { - // set up context for the communication over the network - let communication_ctx = ctx.set_total_records(TotalRecords::specified(self.len())?); - - // set up channel - let send_channel_left = - &communication_ctx.send_channel::(ctx.role().peer(Direction::Left)); - - // send to left - // we send the proof batch via sending the individual field elements - communication_ctx - .parallel_join( - self.iter().enumerate().map(|(i, x)| async move { - send_channel_left.send(RecordId::from(i), x).await - }), - ) - .await?; - Ok(()) + Ok(ctx + .set_total_records(TotalRecords::Indeterminate) + .send_channel::>(ctx.role().peer(Direction::Left)) + .send(record_id, self.to_array()) + .await?) } /// This function receives a `Proof` from the party on the right. /// /// ## Errors /// Propagates errors from receiving values over the network channel. - pub async fn receive_from_right(ctx: &C, length: usize) -> Result + /// + /// ## Panics + /// If the recursion depth implied by `length` exceeds `MAX_PROOF_RECURSION`. + pub async fn receive_from_right( + ctx: &C, + record_id: RecordId, + length: usize, + ) -> Result where C: Context, { - // set up context - let communication_ctx = ctx.set_total_records(TotalRecords::specified(length)?); - - // set up channel - let receive_channel_right = - &communication_ctx.recv_channel::(ctx.role().peer(Direction::Right)); - - // receive from the right + assert!(length <= ARRAY_LEN); Ok(ctx - .parallel_join( - (0..length) - .map(|i| async move { receive_channel_right.receive(RecordId::from(i)).await }), - ) + .set_total_records(TotalRecords::Indeterminate) + .recv_channel::>(ctx.role().peer(Direction::Right)) + .receive(record_id) .await? .into_iter() + .take(length) .collect()) } @@ -251,6 +262,47 @@ impl ProofBatch { } } +const_assert_eq!( + MAX_PROOF_RECURSION, + 9, + "following impl valid only for MAX_PROOF_RECURSION = 9" +); + +#[rustfmt::skip] +type U1464 = UInt, B0>, B1>, B1>, B0>, B1>, B1>, B1>, B0>, B0>, B0>; + +const ARRAY_LEN: usize = 183; +type Array = [Fp61BitPrime; ARRAY_LEN]; + +impl Serializable for Box { + type Size = U1464; + + type DeserializationError = ::DeserializationError; + + fn serialize(&self, buf: &mut generic_array::GenericArray) { + for (hash, buf) in zip( + **self, + buf.chunks_mut(<::Size as Unsigned>::to_usize()), + ) { + hash.serialize(buf.try_into().unwrap()); + } + } + + fn deserialize( + buf: &generic_array::GenericArray, + ) -> Result { + Ok(buf + .chunks(<::Size as Unsigned>::to_usize()) + .map(|buf| Fp61BitPrime::deserialize(buf.try_into().unwrap())) + .collect::, _>>()? + .try_into() + .unwrap()) + } +} + +impl MpcMessage for Box {} + #[cfg(all(test, unit_test))] mod test { use rand::{thread_rng, Rng}; @@ -263,6 +315,7 @@ mod test { proof_generation::ProofBatch, validation::{test::simple_proof_check, BatchToVerify}, }, + RecordId, RecordIdRange, }, secret_sharing::replicated::ReplicatedSecretSharing, test_executor::run, @@ -312,11 +365,13 @@ mod test { q_mask_from_left_prover, ) = ProofBatch::generate( &ctx.narrow("generate_batch"), + RecordIdRange::ALL, uv_tuple_vec.into_iter(), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( ctx.narrow("generate_batch"), + RecordId::FIRST, my_batch_left_shares, shares_of_batch_from_left_prover, p_mask_from_right_prover, diff --git a/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs b/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs index bbb994f70..f0430e996 100644 --- a/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs +++ b/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs @@ -1,16 +1,23 @@ -use std::iter::{once, repeat}; +use std::{ + array, + iter::{once, repeat, zip}, +}; use futures_util::future::{try_join, try_join4}; +use typenum::{Unsigned, U288, U80}; use crate::{ - error::Error, - ff::Fp61BitPrime, + const_assert_eq, + error::{Error, UnwrapInfallible}, + ff::{Fp61BitPrime, Serializable}, helpers::{ hashing::{compute_hash, hash_to_field, Hash}, - Direction, TotalRecords, + Direction, MpcMessage, TotalRecords, }, protocol::{ - context::{step::DzkpProofVerifyStep as Step, Context}, + context::{ + dzkp_validator::MAX_PROOF_RECURSION, step::DzkpProofVerifyStep as Step, Context, + }, ipa_prf::{ malicious_security::{ prover::{LargeProofGenerator, SmallProofGenerator}, @@ -55,6 +62,7 @@ impl BatchToVerify { /// Panics when send and receive over the network channels fail. pub async fn generate_batch_to_verify( ctx: C, + record_id: RecordId, my_batch_left_shares: ProofBatch, shares_of_batch_from_left_prover: ProofBatch, p_mask_from_right_prover: Fp61BitPrime, @@ -66,8 +74,8 @@ impl BatchToVerify { // send one batch left and receive one batch from the right let length = my_batch_left_shares.len(); let ((), shares_of_batch_from_right_prover) = try_join( - my_batch_left_shares.send_to_left(&ctx), - ProofBatch::receive_from_right(&ctx, length), + my_batch_left_shares.send_to_left(&ctx, record_id), + ProofBatch::receive_from_right(&ctx, record_id, length), ) .await .unwrap(); @@ -88,7 +96,11 @@ impl BatchToVerify { /// ## Panics /// Panics when recursion factor constant cannot be converted to `u128` /// or when sending and receiving hashes over the network fails. - pub async fn generate_challenges(&self, ctx: C) -> (Vec, Vec) + pub async fn generate_challenges( + &self, + ctx: C, + record_id: RecordId, + ) -> (Vec, Vec) where C: Context, { @@ -101,15 +113,25 @@ impl BatchToVerify { let exclude_small = u128::try_from(SRF).unwrap(); // generate hashes - let my_hashes_prover_left = ProofHashes::generate_hashes(self, Side::Left); - let my_hashes_prover_right = ProofHashes::generate_hashes(self, Side::Right); + let my_hashes_prover_left = ProofHashes::generate_hashes(self, Direction::Left); + let my_hashes_prover_right = ProofHashes::generate_hashes(self, Direction::Right); // receive hashes from the other verifier let ((), (), other_hashes_prover_left, other_hashes_prover_right) = try_join4( - my_hashes_prover_left.send_hashes(&ctx, Side::Left), - my_hashes_prover_right.send_hashes(&ctx, Side::Right), - ProofHashes::receive_hashes(&ctx, my_hashes_prover_left.hashes.len(), Side::Left), - ProofHashes::receive_hashes(&ctx, my_hashes_prover_right.hashes.len(), Side::Right), + my_hashes_prover_left.send_hashes(&ctx, record_id, Direction::Left), + my_hashes_prover_right.send_hashes(&ctx, record_id, Direction::Right), + ProofHashes::receive_hashes( + &ctx, + record_id, + my_hashes_prover_left.hashes.len(), + Direction::Left, + ), + ProofHashes::receive_hashes( + &ctx, + record_id, + my_hashes_prover_right.hashes.len(), + Direction::Right, + ), ) .await .unwrap(); @@ -174,6 +196,7 @@ impl BatchToVerify { /// This function computes and outputs the final `p_r_right_prover * q_r_right_prover` value. async fn compute_p_times_q( ctx: C, + record_id: RecordId, p_r_right_prover: Fp61BitPrime, q_r_left_prover: Fp61BitPrime, ) -> Result @@ -181,7 +204,7 @@ impl BatchToVerify { C: Context, { // send to the left - let communication_ctx = ctx.set_total_records(TotalRecords::specified(1usize)?); + let communication_ctx = ctx.set_total_records(TotalRecords::Indeterminate); let send_right = communication_ctx.send_channel::(ctx.role().peer(Direction::Right)); @@ -189,8 +212,8 @@ impl BatchToVerify { communication_ctx.recv_channel::(ctx.role().peer(Direction::Left)); let ((), q_r_right_prover) = try_join( - send_right.send(RecordId::FIRST, q_r_left_prover), - receive_left.receive(RecordId::FIRST), + send_right.send(record_id, q_r_left_prover), + receive_left.receive(record_id), ) .await?; @@ -201,9 +224,14 @@ impl BatchToVerify { /// /// ## Errors /// Propagates network errors or when the proof fails to verify. + /// + /// ## Panics + /// If the proof exceeds `MAX_PROOF_RECURSION`. + #[allow(clippy::too_many_arguments)] pub async fn verify( &self, ctx: C, + record_id: RecordId, sum_of_uv_right: Fp61BitPrime, p_r_right_prover: Fp61BitPrime, q_r_left_prover: Fp61BitPrime, @@ -221,6 +249,7 @@ impl BatchToVerify { let p_times_q_right = Self::compute_p_times_q( ctx.narrow(&Step::PTimesQ), + record_id, p_r_right_prover, q_r_left_prover, ) @@ -243,33 +272,26 @@ impl BatchToVerify { p_times_q_right, ); - // send dif_left to the right + // send diff_left to the right let length = diff_left.len(); + assert!(length <= MAX_PROOF_RECURSION + 1); + let communication_ctx = ctx .narrow(&Step::Diff) - .set_total_records(TotalRecords::specified(length)?); - - let send_channel = - communication_ctx.send_channel::(ctx.role().peer(Direction::Right)); - let receive_channel = - communication_ctx.recv_channel::(ctx.role().peer(Direction::Left)); + .set_total_records(TotalRecords::Indeterminate); - let send_channel_ref = &send_channel; - let receive_channel_ref = &receive_channel; + let send_data = array::from_fn(|i| *diff_left.get(i).unwrap_or(&Fp61BitPrime::ZERO)); - let send_future = communication_ctx.parallel_join( - diff_left - .iter() - .enumerate() - .map(|(i, f)| async move { send_channel_ref.send(RecordId::from(i), f).await }), - ); - - let receive_future = communication_ctx.parallel_join( - (0..length) - .map(|i| async move { receive_channel_ref.receive(RecordId::from(i)).await }), - ); - - let (_, diff_right_from_other_verifier) = try_join(send_future, receive_future).await?; + let ((), receive_data) = try_join( + communication_ctx + .send_channel::(ctx.role().peer(Direction::Right)) + .send(record_id, send_data), + communication_ctx + .recv_channel::(ctx.role().peer(Direction::Left)) + .receive(record_id), + ) + .await?; + let diff_right_from_other_verifier = receive_data[0..length].to_vec(); // compare recombined dif to zero for i in 0..length { @@ -286,21 +308,15 @@ struct ProofHashes { hashes: Vec, } -#[derive(Clone, Copy, Debug)] -enum Side { - Left, - Right, -} - impl ProofHashes { - // Generates hashes for proofs received from prover indicated by `side` - fn generate_hashes(batch_to_verify: &BatchToVerify, side: Side) -> Self { - let (first_proof, other_proofs) = match side { - Side::Left => ( + // Generates hashes for proofs received from prover indicated by `direction` + fn generate_hashes(batch_to_verify: &BatchToVerify, direction: Direction) -> Self { + let (first_proof, other_proofs) = match direction { + Direction::Left => ( &batch_to_verify.first_proof_from_left_prover, &batch_to_verify.proofs_from_left_prover, ), - Side::Right => ( + Direction::Right => ( &batch_to_verify.first_proof_from_right_prover, &batch_to_verify.proofs_from_right_prover, ), @@ -314,54 +330,116 @@ impl ProofHashes { } /// Sends the one verifier's hashes to the other verifier - /// `side` indicates the direction of the prover. - async fn send_hashes(&self, ctx: &C, side: Side) -> Result<(), Error> { - let communication_ctx = ctx.set_total_records(TotalRecords::specified(self.hashes.len())?); - - let send_channel = match side { - // send left hashes to the right - Side::Left => communication_ctx.send_channel::(ctx.role().peer(Direction::Right)), - // send right hashes to the left - Side::Right => communication_ctx.send_channel::(ctx.role().peer(Direction::Left)), - }; - let send_channel_ref = &send_channel; - - communication_ctx - .parallel_join(self.hashes.iter().enumerate().map(|(i, hash)| async move { - send_channel_ref.send(RecordId::from(i), hash).await - })) + /// `direction` indicates the direction of the prover. + async fn send_hashes( + &self, + ctx: &C, + record_id: RecordId, + direction: Direction, + ) -> Result<(), Error> { + assert!(self.hashes.len() <= MAX_PROOF_RECURSION); + let hashes_send = + array::from_fn(|i| self.hashes.get(i).unwrap_or(&Hash::default()).clone()); + let verifier_direction = !direction; + ctx.set_total_records(TotalRecords::Indeterminate) + .send_channel::<[Hash; MAX_PROOF_RECURSION]>(ctx.role().peer(verifier_direction)) + .send(record_id, hashes_send) .await?; Ok(()) } /// This function receives hashes from the other verifier - /// `side` indicates the direction of the prover. - async fn receive_hashes(ctx: &C, length: usize, side: Side) -> Result { - // set up context for the communication over the network - let communication_ctx = ctx.set_total_records(TotalRecords::specified(length)?); - - let recv_channel = match side { - // receive left hashes from the right helper - Side::Left => communication_ctx.recv_channel::(ctx.role().peer(Direction::Right)), - // reeive right hashes from the left helper - Side::Right => communication_ctx.recv_channel::(ctx.role().peer(Direction::Left)), - }; - let recv_channel_ref = &recv_channel; - - let hashes_received = communication_ctx - .parallel_join( - (0..length) - .map(|i| async move { recv_channel_ref.receive(RecordId::from(i)).await }), - ) + /// `direction` indicates the direction of the prover. + async fn receive_hashes( + ctx: &C, + record_id: RecordId, + length: usize, + direction: Direction, + ) -> Result { + assert!(length <= MAX_PROOF_RECURSION); + let verifier_direction = !direction; + let hashes_received = ctx + .set_total_records(TotalRecords::Indeterminate) + .recv_channel::<[Hash; MAX_PROOF_RECURSION]>(ctx.role().peer(verifier_direction)) + .receive(record_id) .await?; - Ok(Self { - hashes: hashes_received, + hashes: hashes_received[0..length].to_vec(), }) } } +const_assert_eq!( + MAX_PROOF_RECURSION, + 9, + "following impl valid only for MAX_PROOF_RECURSION = 9" +); + +impl Serializable for [Hash; MAX_PROOF_RECURSION] { + type Size = U288; + + type DeserializationError = ::DeserializationError; + + fn serialize(&self, buf: &mut generic_array::GenericArray) { + for (hash, buf) in zip( + self, + buf.chunks_mut(<::Size as Unsigned>::to_usize()), + ) { + hash.serialize(buf.try_into().unwrap()); + } + } + + fn deserialize( + buf: &generic_array::GenericArray, + ) -> Result { + Ok(buf + .chunks(<::Size as Unsigned>::to_usize()) + .map(|buf| Hash::deserialize(buf.try_into().unwrap()).unwrap_infallible()) + .collect::>() + .try_into() + .unwrap()) + } +} + +impl MpcMessage for [Hash; MAX_PROOF_RECURSION] {} + +const_assert_eq!( + MAX_PROOF_RECURSION, + 9, + "following impl valid only for MAX_PROOF_RECURSION = 9" +); + +type ProofDiff = [Fp61BitPrime; MAX_PROOF_RECURSION + 1]; + +impl Serializable for ProofDiff { + type Size = U80; + + type DeserializationError = ::DeserializationError; + + fn serialize(&self, buf: &mut generic_array::GenericArray) { + for (hash, buf) in zip( + self, + buf.chunks_mut(<::Size as Unsigned>::to_usize()), + ) { + hash.serialize(buf.try_into().unwrap()); + } + } + + fn deserialize( + buf: &generic_array::GenericArray, + ) -> Result { + Ok(buf + .chunks(<::Size as Unsigned>::to_usize()) + .map(|buf| Fp61BitPrime::deserialize(buf.try_into().unwrap())) + .collect::, _>>()? + .try_into() + .unwrap()) + } +} + +impl MpcMessage for ProofDiff {} + #[cfg(all(test, unit_test))] pub mod test { use futures_util::future::try_join; @@ -384,7 +462,7 @@ pub mod test { validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify}, }, prss::SharedRandomness, - RecordId, + RecordId, RecordIdRange, }, secret_sharing::{replicated::ReplicatedSecretSharing, SharedValue}, test_executor::run, @@ -528,11 +606,13 @@ pub mod test { q_mask_from_left_prover, ) = ProofBatch::generate( &ctx.narrow("generate_batch"), + RecordIdRange::ALL, uv_tuple_vec.into_iter(), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( ctx.narrow("generate_batch"), + RecordId::FIRST, my_batch_left_shares, shares_of_batch_from_left_prover, p_mask_from_right_prover, @@ -541,7 +621,9 @@ pub mod test { .await; // generate and output challenges - batch_to_verify.generate_challenges(ctx).await + batch_to_verify + .generate_challenges(ctx, RecordId::FIRST) + .await }) .await; @@ -639,11 +721,13 @@ pub mod test { q_mask_from_left_prover, ) = ProofBatch::generate( &ctx.narrow("generate_batch"), + RecordIdRange::ALL, vec_my_u_and_v.into_iter(), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( ctx.narrow("generate_batch"), + RecordId::FIRST, my_batch_left_shares, shares_of_batch_from_left_prover, p_mask_from_right_prover, @@ -654,7 +738,7 @@ pub mod test { // generate challenges let (challenges_for_left_prover, challenges_for_right_prover) = batch_to_verify - .generate_challenges(ctx.narrow("generate_hash")) + .generate_challenges(ctx.narrow("generate_hash"), RecordId::FIRST) .await; assert_eq!( @@ -743,11 +827,13 @@ pub mod test { q_mask_from_left_prover, ) = ProofBatch::generate( &ctx.narrow("generate_batch"), + RecordIdRange::ALL, vec_my_u_and_v.into_iter(), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( ctx.narrow("generate_batch"), + RecordId::FIRST, my_batch_left_shares, shares_of_batch_from_left_prover, p_mask_from_right_prover, @@ -758,7 +844,7 @@ pub mod test { // generate challenges let (challenges_for_left_prover, challenges_for_right_prover) = batch_to_verify - .generate_challenges(ctx.narrow("generate_hash")) + .generate_challenges(ctx.narrow("generate_hash"), RecordId::FIRST) .await; assert_eq!( @@ -773,7 +859,10 @@ pub mod test { vec_v_from_left_prover.into_iter(), ); - let p_times_q = BatchToVerify::compute_p_times_q(ctx, p, q).await.unwrap(); + let p_times_q = + BatchToVerify::compute_p_times_q(ctx, RecordId::FIRST, p, q) + .await + .unwrap(); let denominator = CanonicalLagrangeDenominator::< Fp61BitPrime, @@ -828,11 +917,13 @@ pub mod test { q_mask_from_left_prover, ) = ProofBatch::generate( &ctx.narrow("generate_batch"), + RecordIdRange::ALL, vec_my_u_and_v.into_iter(), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( ctx.narrow("generate_batch"), + RecordId::FIRST, my_batch_left_shares, shares_of_batch_from_left_prover, p_mask_from_right_prover, @@ -861,7 +952,7 @@ pub mod test { // generate challenges let (challenges_for_left_prover, challenges_for_right_prover) = batch_to_verify - .generate_challenges(ctx.narrow("generate_hash")) + .generate_challenges(ctx.narrow("generate_hash"), RecordId::FIRST) .await; let (p, q) = batch_to_verify.compute_p_and_q_r( @@ -874,6 +965,7 @@ pub mod test { batch_to_verify .verify( v_ctx, + RecordId::FIRST, sum_of_uv_right, p, q, diff --git a/ipa-core/src/protocol/mod.rs b/ipa-core/src/protocol/mod.rs index 28abf9741..18dfc6221 100644 --- a/ipa-core/src/protocol/mod.rs +++ b/ipa-core/src/protocol/mod.rs @@ -10,7 +10,7 @@ pub mod step; use std::{ fmt::{Debug, Display, Formatter}, hash::Hash, - ops::{Add, AddAssign}, + ops::{Add, AddAssign, Range}, }; pub use basics::{BasicProtocols, BooleanProtocols}; @@ -107,6 +107,7 @@ impl From for RecordId { impl RecordId { pub(crate) const FIRST: Self = Self(0); + pub(crate) const LAST: Self = Self(u32::MAX); } impl From for u128 { @@ -147,6 +148,30 @@ impl AddAssign for RecordId { } } +pub struct RecordIdRange(Range); + +impl RecordIdRange { + pub const ALL: RecordIdRange = RecordIdRange(RecordId::FIRST..RecordId::LAST); + + #[cfg(all(test, unit_test))] + fn peek_first(&self) -> RecordId { + self.0.start + } + + fn expect_next(&mut self) -> RecordId { + assert!(self.0.start < self.0.end, "RecordIdRange exhausted"); + let val = self.0.start; + self.0.start += 1; + val + } +} + +impl From> for RecordIdRange { + fn from(value: Range) -> Self { + Self(value) + } +} + /// Helper used when an operation may or may not be associated with a specific record. This is /// also used to prevent some kinds of invalid uses of record ID iteration. For example, trying to /// use the record ID to iterate over both the inner and outer vectors in a `Vec>` is an