Skip to content

Commit

Permalink
Evaluate each f_uji polynomial recursively
Browse files Browse the repository at this point in the history
  • Loading branch information
cargodog committed Jan 26, 2021
1 parent a2ec5b9 commit b7415c2
Showing 1 changed file with 98 additions and 50 deletions.
148 changes: 98 additions & 50 deletions src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,8 @@ impl ArcturusGens {
x_p.push(t.challenge_scalar(b"x"));
}

let u_max = proofs.iter().map(|p| p.f_uji.len()).max().unwrap();

// To efficiently verify a proof in a single multiexponentiation, each
// inequality in the protocol is joined into one large inequality. To
// prevent an attacker from selecting terms that break the
Expand All @@ -562,43 +564,33 @@ impl ArcturusGens {
.map(|_| (0..5).map(|_| Scalar::random(&mut rng)).collect())
.collect::<Vec<Vec<_>>>();

// Compute each f_uj0 from each proof's f_uji (where i = [1..n])
let f_puj0 = proofs
.iter()
.zip(x_p.iter())
// Join f_puj0 & each of p.f_uji to create f_puji
let f_puji = izip!(proofs, &x_p)
.map(|(p, x)| {
p.f_uji
.iter()
.map(|f_ji| {
f_ji.iter()
.map(|f_i| x - f_i.iter().sum::<Scalar>())
.collect()
.map(|f_i| {
once(x - f_i.iter().sum::<Scalar>())
.chain(f_i.iter().cloned())
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
})
.collect()
.collect::<Vec<_>>()
})
.collect::<Vec<Vec<Vec<_>>>>();

// Join f_puj0 & each of p.f_uji to create f_puji
let f_puji = f_puj0.iter().zip(proofs.iter()).map(|(f_uj0, p)| {
f_uj0.iter().zip(p.f_uji.iter()).map(|(f_j0, f_ji)| {
f_j0.iter()
.zip(f_ji.iter())
.map(|(f_0, f_i)| once(f_0).chain(f_i.iter()))
})
});

let u_max = proofs.iter().map(|p| p.f_uji.len()).max().unwrap();
.collect::<Vec<_>>();

// Ring coefficients computed from f_uji from each proof
let f_poly_pk = proofs
// Evaluate each tensor polynomial
let f_poly_pk = f_puji
.iter()
.zip(x_p.iter())
.map(move |(p, x)| {
(0..self.ring_size())
.map(move |k| compute_f_coeff(k, x, self.n, self.m, &p.f_uji))
.collect()
.map(|f_uji| {
cycle_tensor_poly_evals(&f_uji)
.take(self.ring_size())
.collect::<Vec<_>>()
})
.collect::<Vec<Vec<_>>>();
.collect::<Vec<_>>();

// Each of equations (1), (2), (3), (4), & (5) are comprised of terms of point
// multiplications. Below we collect each point to be multiplied, and compute the
Expand Down Expand Up @@ -675,7 +667,12 @@ impl ArcturusGens {
let scalars_H_uji = (0..self.n * self.m * u_max).map(|l| {
izip!(f_puji.clone(), x_p.iter(), wt_pn.iter())
.map(|(f_uji, x, wt_n)| {
let f = f_uji.flatten().flatten().nth(l).unwrap();
let f = f_uji
.into_iter()
.flat_map(|f_ji| f_ji.into_iter())
.flat_map(|f_i| f_i.into_iter())
.nth(l)
.unwrap();
wt_n[0] * f + wt_n[1] * f * (x - f) // Combination of terms from equations (1) & (2)
})
.sum::<Scalar>()
Expand Down Expand Up @@ -840,29 +837,80 @@ fn convert_base(num: usize, base: usize, digits: usize) -> Vec<usize> {
x_j
}

#[inline]
fn compute_f_coeff(
k: usize,
x: &Scalar,
base: usize,
digits: usize,
f_uji: &Vec<Vec<Vec<Scalar>>>,
) -> Scalar {
let k = convert_base(k, base, digits);
f_uji
.iter()
.map(|f_ji| {
(0..digits)
.map(|j| {
if k[j] == 0 {
x - f_ji[j].iter().sum::<Scalar>()
} else {
f_ji[j][k[j] - 1]
struct CycleTensorPolyEvals<'a> {
w: usize,
m: usize,
n: usize,
f_uji: &'a Vec<Vec<Vec<Scalar>>>,
partial_digits_j: Vec<usize>,
partial_prods_ju: Vec<Vec<Scalar>>,
}

fn cycle_tensor_poly_evals<'a>(f_uji: &'a Vec<Vec<Vec<Scalar>>>) -> CycleTensorPolyEvals<'a> {
let w = f_uji.len();
let m = f_uji[0].len();
let n = f_uji[0][0].len();
let partial_digits_j = Vec::with_capacity(m);
let partial_prods_ju = vec![vec![Scalar::one(); w]; m];
CycleTensorPolyEvals {
w,
m,
n,
f_uji,
partial_digits_j,
partial_prods_ju,
}
}

impl CycleTensorPolyEvals<'_> {
// Recursively multiply the factors corresponding to each digit of k
fn update(&mut self) {
if let Some(mut i) = self.partial_digits_j.pop() {
i += 1;
if i >= self.n {
self.update();
} else {
let j = self.partial_digits_j.len();
for u in 0..self.w {
self.partial_prods_ju[j][u] = self.f_uji[u][self.m - j - 1][i];
}
if j > 0 {
for u in 0..self.w {
let prev_part = self.partial_prods_ju[j - 1][u];
self.partial_prods_ju[j][u] *= prev_part;
}
})
.product::<Scalar>()
})
.sum::<Scalar>()
}
self.partial_digits_j.push(i);
}
}
for j in self.partial_digits_j.len()..self.m {
self.partial_digits_j.push(0);
for u in 0..self.w {
self.partial_prods_ju[j][u] = self.f_uji[u][self.m - j - 1][0];
}
if j > 0 {
for u in 0..self.w {
let prev_part = self.partial_prods_ju[j - 1][u];
self.partial_prods_ju[j][u] *= prev_part;
}
}
}
}
}

impl Iterator for CycleTensorPolyEvals<'_> {
type Item = Scalar;

#[inline]
fn next(&mut self) -> Option<Scalar> {
self.update();
Some(self.partial_prods_ju.last().unwrap().iter().sum::<Scalar>())
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
(usize::max_value(), None)
}
}

fn compute_p_j(k: usize, l: usize, a_ji: &Vec<Vec<Scalar>>) -> Vec<Scalar> {
Expand Down

0 comments on commit b7415c2

Please sign in to comment.