
Use record IDs to index proof batches #1350

Merged · 7 commits · Oct 24, 2024
Changes from all commits
2 changes: 1 addition & 1 deletion .clippy.toml
@@ -8,4 +8,4 @@ disallowed-methods = [
{ path = "std::vec::Vec::leak", reason = "Not running the destructors on futures created inside seq_join module will cause UB in IPA. Make sure you don't leak any of those." },
]

future-size-threshold = 8192
future-size-threshold = 10240
2 changes: 1 addition & 1 deletion ipa-core/src/helpers/hashing.rs
@@ -12,7 +12,7 @@ use crate::{
protocol::prss::FromRandomU128,
};

#[derive(Debug, PartialEq)]
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Hash(Output<Sha256>);

impl Serializable for Hash {
12 changes: 12 additions & 0 deletions ipa-core/src/helpers/mod.rs
@@ -7,6 +7,7 @@ use std::{
convert::Infallible,
fmt::{Debug, Display, Formatter},
num::NonZeroUsize,
ops::Not,
};

use generic_array::GenericArray;
@@ -267,6 +268,17 @@ pub enum Direction {
Right,
}

impl Not for Direction {
type Output = Self;

fn not(self) -> Self {
match self {
Direction::Left => Direction::Right,
Direction::Right => Direction::Left,
}
}
}

impl Role {
const H1_STR: &'static str = "H1";
const H2_STR: &'static str = "H2";
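The new `Not` impl for `Direction` lets callers write `!direction` instead of matching by hand. A minimal standalone sketch of the pattern (the enum is reproduced here for illustration, without the surrounding module):

```rust
use std::ops::Not;

// A two-sided direction, mirroring the helper enum in the diff.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Direction {
    Left,
    Right,
}

impl Not for Direction {
    type Output = Self;

    // `!Left == Right` and vice versa, so code that talks to "the other
    // helper" can write `!dir` instead of a manual match.
    fn not(self) -> Self {
        match self {
            Direction::Left => Direction::Right,
            Direction::Right => Direction::Left,
        }
    }
}

fn main() {
    assert_eq!(!Direction::Left, Direction::Right);
    assert_eq!(!!Direction::Right, Direction::Right); // negation is involutive
}
```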
5 changes: 1 addition & 4 deletions ipa-core/src/protocol/context/dzkp_malicious.rs
@@ -13,7 +13,6 @@ use crate::{
context::{
dzkp_validator::{Batch, MaliciousDZKPValidatorInner, Segment},
prss::InstrumentedIndexedSharedRandomness,
step::DzkpBatchStep,
Context as ContextTrait, DZKPContext, InstrumentedSequentialSharedRandomness,
MaliciousContext,
},
@@ -100,9 +99,7 @@ impl<'a, B: ShardBinding> DZKPContext for DZKPUpgraded<'a, B> {
.batcher
.lock()
.unwrap()
.validate_record(record_id, |batch_idx, batch| {
batch.validate(ctx.narrow(&DzkpBatchStep(batch_idx)))
});
.validate_record(record_id, |batch_idx, batch| batch.validate(ctx, batch_idx));

validation_future.await
}
91 changes: 75 additions & 16 deletions ipa-core/src/protocol/context/dzkp_validator.rs
@@ -15,11 +15,14 @@ use crate::{
dzkp_field::{DZKPBaseField, UVTupleBlock},
dzkp_malicious::DZKPUpgraded as MaliciousDZKPUpgraded,
dzkp_semi_honest::DZKPUpgraded as SemiHonestDZKPUpgraded,
step::{DzkpSingleBatchStep, DzkpValidationProtocolStep as Step},
step::DzkpValidationProtocolStep as Step,
Base, Context, DZKPContext, MaliciousContext, MaliciousProtocolSteps,
},
ipa_prf::validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify},
Gate, RecordId,
ipa_prf::{
validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify},
LargeProofGenerator, SmallProofGenerator,
},
Gate, RecordId, RecordIdRange,
},
seq_join::{seq_join, SeqJoin},
sharding::ShardBinding,
@@ -52,6 +55,26 @@ pub const TARGET_PROOF_SIZE: usize = 8192;
#[cfg(not(test))]
pub const TARGET_PROOF_SIZE: usize = 50_000_000;

/// Maximum proof recursion depth.
//
// This is a hard limit. Each GF(2) multiply generates four G values and four H values,
// and the last level of the proof is limited to (small_recursion_factor - 1), so the
// restriction is:
//
// $$ large_recursion_factor * (small_recursion_factor - 1)
// * small_recursion_factor ^ (depth - 2) >= 4 * target_proof_size $$
//
// With large_recursion_factor = 32 and small_recursion_factor = 8, this means:
//
// $$ depth >= log_8 (8/7 * target_proof_size) $$
//
// Because the number of records in a proof batch is often rounded up to a power of two
// (and less significantly, because multiplication intermediate storage gets rounded up
// to blocks of 256), leaving some margin is advised.
//
// The implementation requires that MAX_PROOF_RECURSION is at least 2.
pub const MAX_PROOF_RECURSION: usize = 9;
Comment by @andyleiserson (Collaborator, Author), Oct 15, 2024:

Recursion depth of 9 corresponds to TARGET_PROOF_SIZE of about 117M. Without quicksort batching, this is not quite enough. So either we can wait for quicksort batching, or we can temporarily increase MAX_PROOF_RECURSION slightly. I lean towards the former unless we identify an urgency to take this change.
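The bound in the comment above can be checked numerically. A small sketch (constants taken from the comment: recursion factors 32 and 8; the target proof sizes are the ones quoted in the diff and the review comment):

```rust
// Smallest recursion depth whose capacity covers `target_proof_size`,
// per the bound in the comment:
//   32 * 7 * 8^(depth - 2) >= 4 * target_proof_size
fn min_depth(target_proof_size: u128) -> u32 {
    let mut depth = 2;
    // large_recursion_factor * (small_recursion_factor - 1)
    let mut capacity: u128 = 32 * 7;
    while capacity < 4 * target_proof_size {
        capacity *= 8; // each extra level multiplies capacity by 8
        depth += 1;
    }
    depth
}

fn main() {
    // TARGET_PROOF_SIZE = 50_000_000 needs depth 9 = MAX_PROOF_RECURSION.
    assert_eq!(min_depth(50_000_000), 9);
    // Depth 9's capacity is 32 * 7 * 8^7 / 4 ≈ 117M, matching the
    // review comment's figure.
    assert_eq!(32u128 * 7 * 8u128.pow(7) / 4, 117_440_512);
}
```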
/// `MultiplicationInputsBlock` is a fixed-size block of intermediate values
/// that occur during a multiplication.
/// These values need to be verified since there might have been malicious behavior.
@@ -564,9 +587,22 @@ impl Batch {

/// ## Panics
/// If `usize` to `u128` conversion fails.
pub(super) async fn validate<B: ShardBinding>(self, ctx: Base<'_, B>) -> Result<(), Error> {
pub(super) async fn validate<B: ShardBinding>(
self,
ctx: Base<'_, B>,
batch_index: usize,
) -> Result<(), Error> {
const PRSS_RECORDS_PER_BATCH: usize = LargeProofGenerator::PROOF_LENGTH
+ (MAX_PROOF_RECURSION - 1) * SmallProofGenerator::PROOF_LENGTH
+ 2; // P and Q masks

let proof_ctx = ctx.narrow(&Step::GenerateProof);

let record_id = RecordId::from(batch_index);
let prss_record_id_start = RecordId::from(batch_index * PRSS_RECORDS_PER_BATCH);
let prss_record_id_end = RecordId::from((batch_index + 1) * PRSS_RECORDS_PER_BATCH);
let prss_record_ids = RecordIdRange::from(prss_record_id_start..prss_record_id_end);

if self.is_empty() {
return Ok(());
}
@@ -578,11 +614,12 @@
q_mask_from_left_prover,
) = {
// generate BatchToVerify
ProofBatch::generate(&proof_ctx, self.get_field_values_prover())
ProofBatch::generate(&proof_ctx, prss_record_ids, self.get_field_values_prover())
};

let chunk_batch = BatchToVerify::generate_batch_to_verify(
proof_ctx,
record_id,
my_batch_left_shares,
shares_of_batch_from_left_prover,
p_mask_from_right_prover,
@@ -592,7 +629,7 @@

// generate challenges
let (challenges_for_left_prover, challenges_for_right_prover) = chunk_batch
.generate_challenges(ctx.narrow(&Step::Challenge))
.generate_challenges(ctx.narrow(&Step::Challenge), record_id)
.await;

let (sum_of_uv, p_r_right_prover, q_r_left_prover) = {
@@ -626,6 +663,7 @@
chunk_batch
.verify(
ctx.narrow(&Step::VerifyProof),
record_id,
sum_of_uv,
p_r_right_prover,
q_r_left_prover,
@@ -661,7 +699,18 @@ pub trait DZKPValidator: Send + Sync {
///
/// # Panics
/// May panic if the above restrictions on validator usage are not followed.
async fn validate(self) -> Result<(), Error>;
async fn validate(self) -> Result<(), Error>
where
Self: Sized,
{
self.validate_indexed(0).await
}

/// Validates all of the multiplies associated with this validator, specifying
/// an explicit batch index.
///
/// This should be used when the protocol is explicitly managing batches.
async fn validate_indexed(self, batch_index: usize) -> Result<(), Error>;

/// `is_verified` checks that there are no `MultiplicationInputs` that have not been verified
/// within the associated `DZKPBatch`
@@ -707,6 +756,20 @@ pub trait DZKPValidator: Send + Sync {
}
}

// Wrapper to avoid https://github.com/rust-lang/rust/issues/100013.
pub fn validated_seq_join<'st, V, S, F, O>(
validator: V,
source: S,
) -> impl Stream<Item = Result<O, Error>> + Send + 'st
where
V: DZKPValidator + 'st,
S: Stream<Item = F> + Send + 'st,
F: Future<Output = Result<O, Error>> + Send + 'st,
O: Send + Sync + 'static,
{
validator.validated_seq_join(source)
}
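The `validated_seq_join` free function delegates to the trait method of the same name purely to work around a compiler limitation (rust-lang/rust#100013). A toy illustration of the wrapper pattern, with placeholder types rather than the real validator and stream machinery:

```rust
// Stand-in trait: the real one is `DZKPValidator` with a
// `validated_seq_join` method returning an `impl Stream`.
trait Validator {
    fn process(self, items: Vec<u32>) -> Vec<u32>;
}

struct Passthrough;

impl Validator for Passthrough {
    fn process(self, items: Vec<u32>) -> Vec<u32> {
        items
    }
}

// Free-function wrapper with the same signature as the trait method.
// Callers use it where invoking the method directly would trip the
// compiler issue; the wrapper itself just forwards.
fn process<V: Validator>(validator: V, items: Vec<u32>) -> Vec<u32> {
    validator.process(items)
}

fn main() {
    assert_eq!(process(Passthrough, vec![1, 2, 3]), vec![1, 2, 3]);
}
```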

#[derive(Clone)]
pub struct SemiHonestDZKPValidator<'a, B: ShardBinding> {
context: SemiHonestDZKPUpgraded<'a, B>,
@@ -732,7 +795,7 @@ impl<'a, B: ShardBinding> DZKPValidator for SemiHonestDZKPValidator<'a, B> {
// Semi-honest validator doesn't do anything, so doesn't care.
}

async fn validate(self) -> Result<(), Error> {
async fn validate_indexed(self, _batch_index: usize) -> Result<(), Error> {
Ok(())
}

@@ -778,7 +841,7 @@ impl<'a, B: ShardBinding> DZKPValidator for MaliciousDZKPValidator<'a, B> {
.set_total_records(total_records);
}

async fn validate(mut self) -> Result<(), Error> {
async fn validate_indexed(mut self, batch_index: usize) -> Result<(), Error> {
let arc = self
.inner_ref
.take()
@@ -793,7 +856,7 @@ impl<'a, B: ShardBinding> DZKPValidator for MaliciousDZKPValidator<'a, B> {

batcher
.into_single_batch()
.validate(validate_ctx.narrow(&DzkpSingleBatchStep))
.validate(validate_ctx, batch_index)
.await
}

@@ -1262,16 +1325,12 @@ mod tests {
}

proptest! {
#![proptest_config(ProptestConfig::with_cases(50))]
#![proptest_config(ProptestConfig::with_cases(20))]
#[test]
fn batching_proptest((record_count, max_multiplications_per_gate) in batching()) {
println!("record_count {record_count} batch {max_multiplications_per_gate}");
if record_count / max_multiplications_per_gate >= 192 {
// TODO: #1269, or even if we don't fix that, don't hardcode the limit.
println!("skipping config because batch count exceeds limit of 192");
}
// This condition is correct only for active_work = 16 and record size of 1 byte.
else if max_multiplications_per_gate != 1 && max_multiplications_per_gate % 16 != 0 {
if max_multiplications_per_gate != 1 && max_multiplications_per_gate % 16 != 0 {
// TODO: #1300, read_size | batch_size.
// Note: for active work < 2048, read size matches active work.

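The heart of this file's change is that each batch now derives its PRSS record IDs from its batch index, so batches draw from disjoint, contiguous ranges of the record space instead of narrowing a per-batch step. A hedged sketch of that arithmetic (record IDs modeled as plain `usize`; the per-batch count is an illustrative stand-in, not the real `PROOF_LENGTH` constants):

```rust
use std::ops::Range;

// Illustrative stand-in for the real per-batch count:
// LargeProofGenerator::PROOF_LENGTH
//   + (MAX_PROOF_RECURSION - 1) * SmallProofGenerator::PROOF_LENGTH
//   + 2  (the P and Q masks)
const PRSS_RECORDS_PER_BATCH: usize = 100;

// Each batch owns the slice [i * N, (i + 1) * N) of the PRSS record
// space, so no per-batch `narrow` step is needed to keep them apart.
fn prss_records_for_batch(batch_index: usize) -> Range<usize> {
    batch_index * PRSS_RECORDS_PER_BATCH..(batch_index + 1) * PRSS_RECORDS_PER_BATCH
}

fn main() {
    let a = prss_records_for_batch(0);
    let b = prss_records_for_batch(1);
    assert_eq!(a, 0..100);
    assert_eq!(b, 100..200);
    // Ranges for distinct batches never overlap.
    assert!(a.end <= b.start);
}
```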
12 changes: 0 additions & 12 deletions ipa-core/src/protocol/context/step.rs
@@ -28,18 +28,6 @@ pub(crate) enum ValidateStep {
CheckZero,
}

// This really is only for DZKPs and not for MACs. The MAC protocol uses record IDs to
// count batches. DZKP probably should do the same to avoid the fixed upper limit.
#[derive(CompactStep)]
#[step(count = 600, child = DzkpValidationProtocolStep)]
pub(crate) struct DzkpBatchStep(pub usize);

// This is used when we don't do batched verification, to avoid paying for x256 as many
// steps in compact gate.
#[derive(CompactStep)]
#[step(child = DzkpValidationProtocolStep)]
pub(crate) struct DzkpSingleBatchStep;

#[derive(CompactStep)]
pub(crate) enum DzkpValidationProtocolStep {
/// Step for proof generation
4 changes: 2 additions & 2 deletions ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs
@@ -105,7 +105,7 @@ where
let validator = ctx.clone().dzkp_validator(
MaliciousProtocolSteps {
protocol: &Step::aggregate(depth),
validate: &Step::aggregate_validate(chunk_counter),
validate: &Step::AggregateValidate,
},
// We have to specify usize::MAX here because the procession through
// record IDs is different at each step of the reduction. The batch
@@ -119,7 +119,7 @@ where
Some(&mut record_ids),
)
.await?;
validator.validate().await?;
validator.validate_indexed(chunk_counter).await?;
chunk_counter += 1;
next_intermediate_results.push(result);
}
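In `breakdown_reveal`, all chunks now share the single `AggregateValidate` step and are distinguished by passing `chunk_counter` to `validate_indexed`. A minimal mock of that control flow (the validator here is a stand-in, not the real `DZKPValidator`, and the loop is synchronous for simplicity):

```rust
// Stand-in validator: records which batch index it was validated with.
struct MockValidator<'a> {
    log: &'a mut Vec<usize>,
}

impl<'a> MockValidator<'a> {
    // Mirrors `DZKPValidator::validate_indexed`: consumes the validator,
    // tagging the validation with an explicit batch index.
    fn validate_indexed(self, batch_index: usize) -> Result<(), ()> {
        self.log.push(batch_index);
        Ok(())
    }
}

fn main() {
    let chunks = ["c0", "c1", "c2"];
    let mut log = Vec::new();
    let mut chunk_counter = 0;
    for _chunk in &chunks {
        // One validator per chunk, all on the same step; the batch index
        // keeps their proof records disjoint.
        let validator = MockValidator { log: &mut log };
        validator.validate_indexed(chunk_counter).unwrap();
        chunk_counter += 1;
    }
    assert_eq!(log, vec![0, 1, 2]);
}
```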
10 changes: 5 additions & 5 deletions ipa-core/src/protocol/ipa_prf/aggregation/step.rs
@@ -10,17 +10,17 @@ pub(crate) enum AggregationStep {
#[step(child = crate::protocol::ipa_prf::shuffle::step::OPRFShuffleStep)]
Shuffle,
Reveal,
#[step(child = crate::protocol::context::step::DzkpSingleBatchStep)]
#[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)]
RevealValidate, // only partly used -- see code
#[step(count = 4, child = AggregateChunkStep)]
#[step(count = 4, child = AggregateChunkStep, name = "chunks")]
Aggregate(usize),
#[step(count = 600, child = crate::protocol::context::step::DzkpSingleBatchStep)]
AggregateValidate(usize),
#[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)]
AggregateValidate,
}

// The step count here is duplicated as the AGGREGATE_DEPTH constant in the code.
#[derive(CompactStep)]
#[step(count = 24, child = AggregateValuesStep, name = "depth")]
#[step(count = 24, child = AggregateValuesStep, name = "fold")]
pub(crate) struct AggregateChunkStep(usize);

#[derive(CompactStep)]
10 changes: 10 additions & 0 deletions ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs
@@ -50,6 +50,16 @@
}
}

impl<F, const N: usize> Default for CanonicalLagrangeDenominator<F, N>
where
F: PrimeField + TryFrom<u128>,
<F as TryFrom<u128>>::Error: Debug,
{
fn default() -> Self {
Self::new()
}

Codecov / codecov/patch warning on ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs: added lines 58–60 were not covered by tests.
}

/// `LagrangeTable` is a precomputed table for the Lagrange evaluation.
/// Allows to compute points on the polynomial, i.e. output points,
/// given enough points on the polynomial, i.e. input points,