Skip to content

Commit

Permalink
Table domain separation (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
chancharles92 authored Apr 7, 2022
1 parent 1fc9a48 commit 5eb3059
Show file tree
Hide file tree
Showing 10 changed files with 298 additions and 78 deletions.
146 changes: 100 additions & 46 deletions plonk/src/circuit/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ use rayon::prelude::*;

/// The wire type identifier for range gates.
const RANGE_WIRE_ID: usize = 5;
/// The wire type identifier for the key index in a lookup gate
const LOOKUP_KEY_WIRE_ID: usize = 0;
/// The wire type identifiers for the searched pair values in a lookup gate
const LOOKUP_VAL_1_WIRE_ID: usize = 1;
const LOOKUP_VAL_2_WIRE_ID: usize = 2;
/// The wire type identifiers for the pair values in the lookup table
const TABLE_VAL_1_WIRE_ID: usize = 3;
const TABLE_VAL_2_WIRE_ID: usize = 4;

/// Hardcoded parameters for Plonk systems.
#[derive(Debug, Clone, Copy)]
Expand Down Expand Up @@ -314,29 +322,38 @@ impl<F: FftField> Circuit<F> for PlonkCircuit<F> {
}
// key-value map lookup gates
let mut key_val_table = HashSet::new();
key_val_table.insert((F::zero(), F::zero(), F::zero()));
let mut num_table_elems: u32 = 0;
key_val_table.insert((F::zero(), F::zero(), F::zero(), F::zero()));
let q_lookup_vec = self.q_lookup();
for (gate_id, &q_lookup) in q_lookup_vec.iter().enumerate() {
let q_dom_sep_vec = self.q_dom_sep();
let table_key_vec = self.table_key_vec();
let table_dom_sep_vec = self.table_dom_sep_vec();
// insert table elements
for (gate_id, ((&q_lookup, &table_dom_sep), &table_key)) in q_lookup_vec
.iter()
.zip(table_dom_sep_vec.iter())
.zip(table_key_vec.iter())
.enumerate()
{
if q_lookup != F::zero() {
let key = F::from(num_table_elems);
let val0 = self.witness(self.wire_variable(3, gate_id))?;
let val1 = self.witness(self.wire_variable(4, gate_id))?;
key_val_table.insert((key, val0, val1));
num_table_elems += 1;
let val0 = self.witness(self.wire_variable(TABLE_VAL_1_WIRE_ID, gate_id))?;
let val1 = self.witness(self.wire_variable(TABLE_VAL_2_WIRE_ID, gate_id))?;
key_val_table.insert((table_dom_sep, table_key, val0, val1));
}
}
for (gate_id, &q_lookup) in q_lookup_vec.iter().enumerate() {
// check lookups
for (gate_id, (&q_lookup, &q_dom_sep)) in
q_lookup_vec.iter().zip(q_dom_sep_vec.iter()).enumerate()
{
if q_lookup != F::zero() {
let key = self.witness(self.wire_variable(0, gate_id))?;
let val0 = self.witness(self.wire_variable(1, gate_id))?;
let val1 = self.witness(self.wire_variable(2, gate_id))?;
if !key_val_table.contains(&(key, val0, val1)) {
let key = self.witness(self.wire_variable(LOOKUP_KEY_WIRE_ID, gate_id))?;
let val0 = self.witness(self.wire_variable(LOOKUP_VAL_1_WIRE_ID, gate_id))?;
let val1 = self.witness(self.wire_variable(LOOKUP_VAL_2_WIRE_ID, gate_id))?;
if !key_val_table.contains(&(q_dom_sep, key, val0, val1)) {
return Err(GateCheckFailure(
gate_id,
format!(
"Lookup gate failed: ({}, {}, {}) not in the table",
key, val0, val1
"Lookup gate failed: ({}, {}, {}, {}) not in the table",
q_dom_sep, key, val0, val1
),
)
.into());
Expand Down Expand Up @@ -788,6 +805,21 @@ impl<F: FftField> PlonkCircuit<F> {
fn q_lookup(&self) -> Vec<F> {
self.gates.iter().map(|g| g.q_lookup()).collect()
}
// getter for all lookup domain separation selector
#[inline]
fn q_dom_sep(&self) -> Vec<F> {
self.gates.iter().map(|g| g.q_dom_sep()).collect()
}
// getter for the vector of table keys
#[inline]
fn table_key_vec(&self) -> Vec<F> {
self.gates.iter().map(|g| g.table_key()).collect()
}
// getter for the vector of table domain separation ids
#[inline]
fn table_dom_sep_vec(&self) -> Vec<F> {
self.gates.iter().map(|g| g.table_dom_sep()).collect()
}
// TODO: (alex) try return reference instead of expensive clone
// getter for all selectors in the following order:
// q_lc, q_mul, q_hash, q_o, q_c, q_ecc, [q_lookup (if support lookup)]
Expand Down Expand Up @@ -1171,24 +1203,45 @@ where
}

fn compute_key_table_polynomial(&self) -> Result<DensePolynomial<F>, PlonkError> {
let key_table = self.compute_key_table()?;
self.check_plonk_type(PlonkType::UltraPlonk)?;
self.check_finalize_flag(true)?;
let domain = &self.eval_domain;
Ok(DensePolynomial::from_coefficients_vec(
domain.ifft(&self.table_key_vec()),
))
}

fn compute_table_dom_sep_polynomial(&self) -> Result<DensePolynomial<F>, PlonkError> {
self.check_plonk_type(PlonkType::UltraPlonk)?;
self.check_finalize_flag(true)?;
let domain = &self.eval_domain;
Ok(DensePolynomial::from_coefficients_vec(
domain.ifft(&self.table_dom_sep_vec()),
))
}

fn compute_q_dom_sep_polynomial(&self) -> Result<DensePolynomial<F>, PlonkError> {
self.check_plonk_type(PlonkType::UltraPlonk)?;
self.check_finalize_flag(true)?;
let domain = &self.eval_domain;
Ok(DensePolynomial::from_coefficients_vec(
domain.ifft(&key_table),
domain.ifft(&self.q_dom_sep()),
))
}

fn compute_merged_lookup_table(&self, tau: F) -> Result<Vec<F>, PlonkError> {
let range_table = self.compute_range_table()?;
let key_table = self.compute_key_table()?;
let table_key_vec = self.table_key_vec();
let table_dom_sep_vec = self.table_dom_sep_vec();
let q_lookup_vec = self.q_lookup();

let mut merged_lookup_table = vec![];
for i in 0..self.eval_domain_size()? {
merged_lookup_table.push(self.merged_table_value(
tau,
&range_table,
&key_table,
&table_key_vec,
&table_dom_sep_vec,
&q_lookup_vec,
i,
)?);
Expand Down Expand Up @@ -1230,9 +1283,11 @@ where
let beta_plus_one = F::one() + *beta;
let gamma_mul_beta_plus_one = *gamma * beta_plus_one;
let q_lookup_vec = self.q_lookup();
let q_dom_sep_vec = self.q_dom_sep();
for j in 0..(n - 2) {
// compute merged lookup witness value
let lookup_wire_val = self.merged_lookup_wire_value(*tau, j, &q_lookup_vec)?;
let lookup_wire_val =
self.merged_lookup_wire_value(*tau, j, &q_lookup_vec, &q_dom_sep_vec)?;
let table_val = merged_lookup_table[j];
let table_next_val = merged_lookup_table[j + 1];
let h1_val = sorted_vec[j];
Expand Down Expand Up @@ -1281,8 +1336,9 @@ where
// only the first n-1 variables are for lookup
let mut lookup_map = HashMap::<F, usize>::new();
let q_lookup_vec = self.q_lookup();
let q_dom_sep_vec = self.q_dom_sep();
for i in 0..(n - 1) {
let elem = self.merged_lookup_wire_value(tau, i, &q_lookup_vec)?;
let elem = self.merged_lookup_wire_value(tau, i, &q_lookup_vec, &q_dom_sep_vec)?;
let n_lookups = lookup_map.entry(elem).or_insert(0);
*n_lookups += 1;
}
Expand Down Expand Up @@ -1328,35 +1384,26 @@ impl<F: PrimeField> PlonkCircuit<F> {
Ok(range_table)
}

#[inline]
// TODO: generalize to arbitrary key sets.
fn compute_key_table(&self) -> Result<Vec<F>, PlonkError> {
self.check_plonk_type(PlonkType::UltraPlonk)?;
self.check_finalize_flag(true)?;
let n = self.eval_domain_size()?;
let mut key_table = vec![F::zero(); n - 1 - self.num_table_elems];
for i in 0..self.num_table_elems {
key_table.push(F::from(i as u32));
}
key_table.push(F::zero());
Ok(key_table)
}

#[inline]
fn merged_table_value(
&self,
tau: F,
range_table: &[F],
key_table: &[F],
table_key_vec: &[F],
table_dom_sep_vec: &[F],
q_lookup_vec: &[F],
i: usize,
) -> Result<F, PlonkError> {
let range_val = range_table[i];
let key_val = key_table[i];
let key_val = table_key_vec[i];
let dom_sep_val = table_dom_sep_vec[i];
let q_lookup_val = q_lookup_vec[i];
let w3_val = self.witness(self.wire_variable(3, i))?;
let w4_val = self.witness(self.wire_variable(4, i))?;
Ok(range_val + q_lookup_val * tau * (key_val + tau * (w3_val + tau * w4_val)))
let table_val_1 = self.witness(self.wire_variable(TABLE_VAL_1_WIRE_ID, i))?;
let table_val_2 = self.witness(self.wire_variable(TABLE_VAL_2_WIRE_ID, i))?;
Ok(range_val
+ q_lookup_val
* tau
* (dom_sep_val + tau * (key_val + tau * (table_val_1 + tau * table_val_2))))
}

#[inline]
Expand All @@ -1365,13 +1412,18 @@ impl<F: PrimeField> PlonkCircuit<F> {
tau: F,
i: usize,
q_lookup_vec: &[F],
q_dom_sep_vec: &[F],
) -> Result<F, PlonkError> {
let w_range_val = self.witness(self.wire_variable(RANGE_WIRE_ID, i))?;
let w_0_val = self.witness(self.wire_variable(0, i))?;
let w_1_val = self.witness(self.wire_variable(1, i))?;
let w_2_val = self.witness(self.wire_variable(2, i))?;
let lookup_key = self.witness(self.wire_variable(LOOKUP_KEY_WIRE_ID, i))?;
let lookup_val_1 = self.witness(self.wire_variable(LOOKUP_VAL_1_WIRE_ID, i))?;
let lookup_val_2 = self.witness(self.wire_variable(LOOKUP_VAL_2_WIRE_ID, i))?;
let q_lookup_val = q_lookup_vec[i];
Ok(w_range_val + q_lookup_val * tau * (w_0_val + tau * (w_1_val + tau * w_2_val)))
let q_dom_sep_val = q_dom_sep_vec[i];
Ok(w_range_val
+ q_lookup_val
* tau
* (q_dom_sep_val + tau * (lookup_key + tau * (lookup_val_1 + tau * lookup_val_2))))
}
}

Expand Down Expand Up @@ -1969,7 +2021,7 @@ pub(crate) mod test {

// Check key table polynomial
let key_table_poly = circuit.compute_key_table_polynomial()?;
let key_table = circuit.compute_key_table()?;
let key_table = circuit.table_key_vec();
check_polynomial(&key_table_poly, &key_table);

// Check sorted vector polynomials
Expand Down Expand Up @@ -2016,8 +2068,10 @@ pub(crate) mod test {
let one_plus_beta = F::one() + beta;
let gamma_mul_one_plus_beta = gamma * one_plus_beta;
let q_lookup_vec = circuit.q_lookup();
let q_dom_sep = circuit.q_dom_sep();
for j in 0..(n - 2) {
let lookup_wire_val = circuit.merged_lookup_wire_value(tau, j, &q_lookup_vec)?;
let lookup_wire_val =
circuit.merged_lookup_wire_value(tau, j, &q_lookup_vec, &q_dom_sep)?;
let table_val = merged_lookup_table[j];
let table_next_val = merged_lookup_table[j + 1];
let h1_val = sorted_vec[j];
Expand Down
3 changes: 0 additions & 3 deletions plonk/src/circuit/customized/ecc/msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,6 @@ where

// create circuit
let range_size = F::from((1 << c) as u32);
for &var in decomposed_scalar_vars.iter() {
circuit.range_gate(var, c)?;
}
circuit.decompose_vars_gate(decomposed_scalar_vars.clone(), scalar_var, range_size)?;

Ok(decomposed_scalar_vars)
Expand Down
17 changes: 15 additions & 2 deletions plonk/src/circuit/customized/gates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,13 @@ where

/// An UltraPlonk lookup gate
#[derive(Debug, Clone)]
pub struct LookupGate;
pub struct LookupGate<F: Field> {
pub(crate) q_dom_sep: F,
pub(crate) table_dom_sep: F,
pub(crate) table_key: F,
}

impl<F> Gate<F> for LookupGate
impl<F> Gate<F> for LookupGate<F>
where
F: Field,
{
Expand All @@ -317,4 +321,13 @@ where
fn q_lookup(&self) -> F {
F::one()
}
fn q_dom_sep(&self) -> F {
self.q_dom_sep
}
fn table_key(&self) -> F {
self.table_key
}
fn table_dom_sep(&self) -> F {
self.table_dom_sep
}
}
73 changes: 51 additions & 22 deletions plonk/src/circuit/customized/ultraplonk/lookup_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,14 @@ use crate::{
errors::PlonkError,
};
use ark_ff::PrimeField;
use ark_std::{boxed::Box, cmp::max, vec::Vec};
use ark_std::{boxed::Box, cmp::max};

impl<F: PrimeField> PlonkCircuit<F> {
/// Create a table with keys/values
/// [table_id, ..., table_id + n - 1] and
/// [table_vars\[0\], ..., table_vars[n - 1]];
/// [0, ..., n - 1] and
/// [table_vars\[0\], ..., table_vars\[n - 1\]];
/// and create a list of variable tuples to be looked up:
/// [lookup_vars\[0\], ..., lookup_vars[m - 1]];
///
/// **For each variable tuple `(lookup_var.0, lookup_var.1, lookup_var.2)`
/// to be looked up, the index variable `lookup_var.0` is required to be
/// in range [0, n) (either constrained by a range-check gate or other
/// circuits), so that one can't set it out of bounds and thus do a
/// lookup into one of the *other* tables. **
/// [lookup_vars\[0\], ..., lookup_vars\[m - 1\]];
///
/// w.l.o.g we assume n = m as we can pad with dummy tuples when n != m
pub fn create_table_and_lookup_variables(
Expand All @@ -42,24 +36,38 @@ impl<F: PrimeField> PlonkCircuit<F> {
self.check_var_bound(table_var.1)?;
}
let n = max(lookup_vars.len(), table_vars.len());
// update lookup keys for domain separation.
let lookup_keys: Vec<Variable> = lookup_vars
.iter()
.map(|&(key, ..)| self.add_constant(key, &F::from(self.num_table_elems() as u32)))
.collect::<Result<Vec<_>, _>>()?;
let n_gate = self.num_gates();
(*self.table_gate_ids_mut()).push((n_gate, n));
let table_ctr = F::from(self.table_gate_ids_mut().len() as u64);
for i in 0..n {
let (key, val0, val1) = match i < lookup_vars.len() {
true => (lookup_keys[i], lookup_vars[i].1, lookup_vars[i].2),
false => (self.zero(), self.zero(), self.zero()),
let (q_dom_sep, key, val0, val1) = match i < lookup_vars.len() {
true => (
table_ctr,
lookup_vars[i].0,
lookup_vars[i].1,
lookup_vars[i].2,
),
false => (F::zero(), self.zero(), self.zero(), self.zero()),
};
let (table_val0, table_val1) = match i < table_vars.len() {
true => table_vars[i],
false => (self.zero(), self.zero()),
let (table_dom_sep, table_key, table_val0, table_val1) = match i < table_vars.len() {
true => (
table_ctr,
F::from(i as u64),
table_vars[i].0,
table_vars[i].1,
),
false => (F::zero(), F::zero(), self.zero(), self.zero()),
};
let wire_vars = [key, val0, val1, table_val0, table_val1];
self.insert_gate(&wire_vars, Box::new(LookupGate))?;

self.insert_gate(
&wire_vars,
Box::new(LookupGate {
q_dom_sep,
table_dom_sep,
table_key,
}),
)?;
}
*self.num_table_elems_mut() += n;
Ok(())
Expand Down Expand Up @@ -136,6 +144,27 @@ mod test {
.create_table_and_lookup_variables(&lookup_vars, &bad_table_vars)
.is_err());

// A lookup over a separate table should not satisfy the circuit.
let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_ultra_plonk(4);
let mut rng = test_rng();

let val0 = circuit.create_variable(F::rand(&mut rng))?;
let val1 = circuit.create_variable(F::rand(&mut rng))?;
let table_vars_1 = vec![(val0, val1)];
let val2 = circuit.create_variable(F::rand(&mut rng))?;
let val3 = circuit.create_variable(F::rand(&mut rng))?;
let table_vars_2 = vec![(val2, val3)];
let val2 = circuit.witness(table_vars_2[0].0)?;
let val3 = circuit.witness(table_vars_2[0].1)?;
let val2_var = circuit.create_variable(val2)?;
let val3_var = circuit.create_variable(val3)?;
let lookup_vars_1 = vec![(circuit.zero(), val2_var, val3_var)];

circuit.create_table_and_lookup_variables(&lookup_vars_1, &table_vars_2)?;
assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
circuit.create_table_and_lookup_variables(&lookup_vars_1, &table_vars_1)?;
assert!(circuit.check_circuit_satisfiability(&[]).is_err());

Ok(())
}
}
Loading

0 comments on commit 5eb3059

Please sign in to comment.