Skip to content

Commit

Permalink
add and improve documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-ferdinand committed Feb 17, 2023
1 parent 6780136 commit 9386878
Showing 1 changed file with 94 additions and 77 deletions.
171 changes: 94 additions & 77 deletions triton-vm/src/table/constraint_circuit.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
//! Constraint circuits are a way to represent constraint polynomials in a way that is amenable
//! to optimizations. The constraint circuit is a directed acyclic graph (DAG) of
//! [`CircuitExpression`]s, where each `CircuitExpression` is a node in the graph. The edges of the
//! graph are labeled with [`BinOp`]s. The leafs of the graph are the inputs to the constraint
//! polynomial, and the (multiple) roots of the graph are the outputs of all the
//! constraint polynomials, with each root corresponding to a different constraint polynomial.
//! Because the graph has multiple roots, it is called a “multitree.”
use std::borrow::BorrowMut;
use std::cell::RefCell;
use std::cmp;
Expand Down Expand Up @@ -42,19 +50,17 @@ impl Display for BinOp {
}
}

/// An `InputIndicator` is a type that describes the position of a variable in
/// a constraint polynomial in the row layout applicable for a certain kind of
/// constraint polynomial.
///
/// A variable in a constraint polynomial comes in the shape of a `usize`, but
/// depending on the type of constraint polynomial, this index may be an index
/// into a single row (for initial, consistency and terminal constraints), or
/// a pair of adjacent rows (for transition constraints), or some other layout
/// for a third type of constraint.
/// Describes the position of a variable in a constraint polynomial in the row layout applicable
/// for a certain kind of constraint polynomial.
///
/// `From<usize>` and `Into<usize>` occur for the purpose of this conversion.
/// The position of variable in a constraint polynomial is, in principle, a `usize`. However,
/// depending on the type of the constraint polynomial, this index may be an index into a single
/// row (for initial, consistency and terminal constraints), or a pair of adjacent rows (for
/// transition constraints). Additionally, the index may refer to a column in the base table, or
/// a column in the extension table. This trait abstracts over these possibilities, and provides
/// a uniform interface for accessing the index.
///
/// Having `Clone + Copy + Hash + PartialEq + Eq` help put these in containers.
/// Having `Clone + Copy + Hash + PartialEq + Eq` helps putting `InputIndicator`s into containers.
pub trait InputIndicator: Debug + Clone + Copy + Hash + PartialEq + Eq + Display {
/// `true` iff `self` refers to a column in the base table.
fn is_base_table_row(&self) -> bool;
Expand All @@ -69,9 +75,8 @@ pub trait InputIndicator: Debug + Clone + Copy + Hash + PartialEq + Eq + Display
) -> XFieldElement;
}

/// A `SingleRowIndicator<BASE_COLUMN_COUNT, EXT_COLUMN_COUNT>` describes the position of a variable in
/// a constraint polynomial that operates on a single execution trace table at a
/// time.
/// The position of a variable in a constraint polynomial that operates on a single row of the
/// execution trace.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum SingleRowIndicator<const BASE_COLUMN_COUNT: usize, const EXT_COLUMN_COUNT: usize> {
BaseRow(usize),
Expand All @@ -82,9 +87,10 @@ impl<const BASE_COLUMN_COUNT: usize, const EXT_COLUMN_COUNT: usize> Display
for SingleRowIndicator<BASE_COLUMN_COUNT, EXT_COLUMN_COUNT>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use SingleRowIndicator::*;
let input_indicator: String = match self {
SingleRowIndicator::BaseRow(i) => format!("base_row[{i}]"),
SingleRowIndicator::ExtRow(i) => format!("ext_row[{i}]"),
BaseRow(i) => format!("base_row[{i}]"),
ExtRow(i) => format!("ext_row[{i}]"),
};

writeln!(f, "{input_indicator}")
Expand All @@ -95,23 +101,23 @@ impl<const BASE_COLUMN_COUNT: usize, const EXT_COLUMN_COUNT: usize> InputIndicat
for SingleRowIndicator<BASE_COLUMN_COUNT, EXT_COLUMN_COUNT>
{
fn is_base_table_row(&self) -> bool {
match self {
SingleRowIndicator::BaseRow(_) => true,
SingleRowIndicator::ExtRow(_) => false,
}
use SingleRowIndicator::*;
matches!(self, BaseRow(_))
}

fn base_row_index(&self) -> usize {
use SingleRowIndicator::*;
match self {
SingleRowIndicator::BaseRow(i) => *i,
SingleRowIndicator::ExtRow(_) => panic!("not a base row"),
BaseRow(i) => *i,
ExtRow(_) => panic!("not a base row"),
}
}

fn ext_row_index(&self) -> usize {
use SingleRowIndicator::*;
match self {
SingleRowIndicator::BaseRow(_) => panic!("not an ext row"),
SingleRowIndicator::ExtRow(i) => *i,
BaseRow(_) => panic!("not an ext row"),
ExtRow(i) => *i,
}
}

Expand All @@ -120,15 +126,16 @@ impl<const BASE_COLUMN_COUNT: usize, const EXT_COLUMN_COUNT: usize> InputIndicat
base_table: ArrayView2<BFieldElement>,
ext_table: ArrayView2<XFieldElement>,
) -> XFieldElement {
use SingleRowIndicator::*;
match self {
SingleRowIndicator::BaseRow(i) => base_table[[0, *i]].lift(),
SingleRowIndicator::ExtRow(i) => ext_table[[0, *i]],
BaseRow(i) => base_table[[0, *i]].lift(),
ExtRow(i) => ext_table[[0, *i]],
}
}
}

/// A `DualRowIndicator<BASE_COLUMN_COUNT, EXT_COLUMN_COUNT>` describes the position of a variable in
/// a constraint polynomial that operates on pairs of rows (current and next).
/// The position of a variable in a constraint polynomial that operates on two rows (current and
/// next) of the execution trace.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum DualRowIndicator<const BASE_COLUMN_COUNT: usize, const EXT_COLUMN_COUNT: usize> {
CurrentBaseRow(usize),
Expand All @@ -141,11 +148,12 @@ impl<const BASE_COLUMN_COUNT: usize, const EXT_COLUMN_COUNT: usize> Display
for DualRowIndicator<BASE_COLUMN_COUNT, EXT_COLUMN_COUNT>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use DualRowIndicator::*;
let input_indicator: String = match self {
DualRowIndicator::CurrentBaseRow(i) => format!("current_base_row[{i}]"),
DualRowIndicator::CurrentExtRow(i) => format!("current_ext_row[{i}]"),
DualRowIndicator::NextBaseRow(i) => format!("next_base_row[{i}]"),
DualRowIndicator::NextExtRow(i) => format!("next_ext_row[{i}]"),
CurrentBaseRow(i) => format!("current_base_row[{i}]"),
CurrentExtRow(i) => format!("current_ext_row[{i}]"),
NextBaseRow(i) => format!("next_base_row[{i}]"),
NextExtRow(i) => format!("next_ext_row[{i}]"),
};

writeln!(f, "{input_indicator}")
Expand All @@ -156,29 +164,23 @@ impl<const BASE_COLUMN_COUNT: usize, const EXT_COLUMN_COUNT: usize> InputIndicat
for DualRowIndicator<BASE_COLUMN_COUNT, EXT_COLUMN_COUNT>
{
fn is_base_table_row(&self) -> bool {
match self {
DualRowIndicator::CurrentBaseRow(_) => true,
DualRowIndicator::CurrentExtRow(_) => false,
DualRowIndicator::NextBaseRow(_) => true,
DualRowIndicator::NextExtRow(_) => false,
}
use DualRowIndicator::*;
matches!(self, CurrentBaseRow(_) | NextBaseRow(_))
}

fn base_row_index(&self) -> usize {
use DualRowIndicator::*;
match self {
DualRowIndicator::CurrentBaseRow(i) => *i,
DualRowIndicator::CurrentExtRow(_) => panic!("not a base row"),
DualRowIndicator::NextBaseRow(i) => *i,
DualRowIndicator::NextExtRow(_) => panic!("not a base row"),
CurrentBaseRow(i) | NextBaseRow(i) => *i,
CurrentExtRow(_) | NextExtRow(_) => panic!("not a base row"),
}
}

fn ext_row_index(&self) -> usize {
use DualRowIndicator::*;
match self {
DualRowIndicator::CurrentBaseRow(_) => panic!("not an ext row"),
DualRowIndicator::CurrentExtRow(i) => *i,
DualRowIndicator::NextBaseRow(_) => panic!("not an ext row"),
DualRowIndicator::NextExtRow(i) => *i,
CurrentBaseRow(_) | NextBaseRow(_) => panic!("not an ext row"),
CurrentExtRow(i) | NextExtRow(i) => *i,
}
}

Expand All @@ -187,15 +189,32 @@ impl<const BASE_COLUMN_COUNT: usize, const EXT_COLUMN_COUNT: usize> InputIndicat
base_table: ArrayView2<BFieldElement>,
ext_table: ArrayView2<XFieldElement>,
) -> XFieldElement {
use DualRowIndicator::*;
match self {
DualRowIndicator::CurrentBaseRow(i) => base_table[[0, *i]].lift(),
DualRowIndicator::CurrentExtRow(i) => ext_table[[0, *i]],
DualRowIndicator::NextBaseRow(i) => base_table[[1, *i]].lift(),
DualRowIndicator::NextExtRow(i) => ext_table[[1, *i]],
CurrentBaseRow(i) => base_table[[0, *i]].lift(),
CurrentExtRow(i) => ext_table[[0, *i]],
NextBaseRow(i) => base_table[[1, *i]].lift(),
NextExtRow(i) => ext_table[[1, *i]],
}
}
}

/// A circuit expression is the recursive data structure that represents the constraint polynomials.
/// It is a directed, acyclic graph of binary operations on the variables of the constraint
/// polynomials, constants, and challenges. It has multiple roots, making it a “multitree.” Each
/// root corresponds to one constraint polynomial.
///
/// The leafs of the tree are
/// - constants in the base field, _i.e._, [`BFieldElement`]s,
/// - constants in the extension field, _i.e._, [`XFieldElement`]s,
/// - input variables, _i.e._, entries from the Algebraic Execution Trace, and
/// - challenges, _i.e._, (pseudo-)random values sampled through the Fiat-Shamir heuristic.
///
/// An inner node, representing some binary operation, is either addition, multiplication, or
/// subtraction. The left and right children of the node are the operands of the binary operation.
/// The left and right children are not themselves `CircuitExpression`s, but rather
/// [`ConstraintCircuit`]s, which is a wrapper around `CircuitExpression` that manages additional
/// bookkeeping information.
#[derive(Debug, Clone)]
pub enum CircuitExpression<II: InputIndicator> {
XConstant(XFieldElement),
Expand Down Expand Up @@ -239,6 +258,7 @@ impl<II: InputIndicator> Hash for ConstraintCircuitMonad<II> {
}
}

/// A wrapper around a [`CircuitExpression`] that manages additional bookkeeping information.
#[derive(Clone, Debug)]
pub struct ConstraintCircuit<II: InputIndicator> {
pub id: usize,
Expand All @@ -249,11 +269,9 @@ pub struct ConstraintCircuit<II: InputIndicator> {
impl<II: InputIndicator> Eq for ConstraintCircuit<II> {}

impl<II: InputIndicator> PartialEq for ConstraintCircuit<II> {
/// Calculate equality of circuits.
/// In particular, this function does *not* attempt to simplify
/// or reduce neutral terms or products. So this comparison will
/// return false for `a == a + 0`. It will also return false for
/// `XFieldElement(7) == BFieldElement(7)`
/// Calculate equality of circuits. In particular, this function does *not* attempt to
/// simplify or reduce neutral terms or products. So this comparison will return false for
/// `a == a + 0`. It will also return false for `XFieldElement(7) == BFieldElement(7)`
fn eq(&self, other: &Self) -> bool {
match &self.expression {
XConstant(self_xfe) => match &other.expression {
Expand Down Expand Up @@ -375,17 +393,15 @@ impl<II: InputIndicator> ConstraintCircuit<II> {
}
}

/// Apply constant folding to simplify the (sub)tree.
/// If the subtree is a leaf (terminal), no change.
/// If the subtree is a binary operation on:
/// Apply constant folding to simplify the (sub)tree. If the subtree is a leaf (terminal), no
/// change. If the subtree is a binary operation on:
///
/// - one constant x one constant => fold
/// - one constant x one expr => can't
/// - one expr x one constant => can't
/// - one expr x one expr => can't
///
/// This operation mutates self and returns true if a change was
/// applied anywhere in the tree.
/// This operation mutates self and returns true if a change was applied anywhere in the tree.
fn constant_fold_inner(&mut self) -> bool {
let mut change_tracker = false;
if let BinaryOperation(_, lhs, rhs) = &self.expression {
Expand Down Expand Up @@ -595,8 +611,8 @@ impl<II: InputIndicator> ConstraintCircuit<II> {
!matches!(&self.expression, BinaryOperation(_, _, _))
}

/// Return true if this node represents a constant value of zero, does not
/// catch composite expressions that will always evaluate to zero.
/// Return true if this node represents a constant value of zero, does not catch composite
/// expressions that will always evaluate to zero.
pub fn is_zero(&self) -> bool {
match self.expression {
BConstant(bfe) => bfe.is_zero(),
Expand All @@ -605,8 +621,8 @@ impl<II: InputIndicator> ConstraintCircuit<II> {
}
}

/// Return true if this node represents a constant value of one, does not
/// catch composite expressions that will always evaluate to one.
/// Return true if this node represents a constant value of one, does not catch composite
/// expressions that will always evaluate to one.
pub fn is_one(&self) -> bool {
match self.expression {
XConstant(xfe) => xfe.is_one(),
Expand All @@ -615,7 +631,7 @@ impl<II: InputIndicator> ConstraintCircuit<II> {
}
}

/// Return true iff the evaluation value of this node depends on a challenge
/// Return true iff the evaluation value of this node depends on a challenge.
pub fn is_randomized(&self) -> bool {
match &self.expression {
Challenge(_) => true,
Expand All @@ -626,7 +642,7 @@ impl<II: InputIndicator> ConstraintCircuit<II> {
}
}

/// Replace all challenges with constants in subtree
/// Replace all challenges with constants in subtree.
fn apply_challenges_to_one_root(&mut self, challenges: &Challenges) {
match &self.expression {
Challenge(challenge_id) => {
Expand All @@ -644,7 +660,7 @@ impl<II: InputIndicator> ConstraintCircuit<II> {
}
}

/// Simplify the circuit constraints by replacing the known challenges with roots
/// Simplify the circuit constraints by replacing the known challenges with roots.
pub fn apply_challenges(constraints: &mut [ConstraintCircuit<II>], challenges: &Challenges) {
for circuit in constraints.iter_mut() {
circuit.apply_challenges_to_one_root(challenges);
Expand Down Expand Up @@ -685,6 +701,8 @@ impl<II: InputIndicator> ConstraintCircuit<II> {
}
}

/// The inner type used in the [`ConstraintCircuitBuilder`] to build a circuit. Provides
/// convenience methods, for example by allowing to use `+` and `*` to add and multiply.
#[derive(Clone)]
pub struct ConstraintCircuitMonad<II: InputIndicator> {
pub circuit: Rc<RefCell<ConstraintCircuit<II>>>,
Expand Down Expand Up @@ -726,9 +744,9 @@ impl<II: InputIndicator> PartialEq for ConstraintCircuitMonad<II> {

impl<II: InputIndicator> Eq for ConstraintCircuitMonad<II> {}

/// Helper function for binary operations that are used to generate new parent
/// nodes in the multitree that represents the algebraic circuit. Ensures that
/// each newly created node has a unique ID.
/// Helper function for binary operations that are used to generate new parent nodes in the
/// multitree that represents the algebraic circuit. Ensures that each newly created node has a
/// unique ID.
fn binop<II: InputIndicator>(
binop: BinOp,
lhs: ConstraintCircuitMonad<II>,
Expand Down Expand Up @@ -819,9 +837,8 @@ impl<II: InputIndicator> Mul for ConstraintCircuitMonad<II> {
}
}

/// This will panic if the iterator is empty because the neutral
/// element needs a unique ID, and we have no way of getting that
/// here.
/// This will panic if the iterator is empty because the neutral element needs a unique ID, and
/// we have no way of getting that here.
impl<II: InputIndicator> Sum for ConstraintCircuitMonad<II> {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.reduce(|accum, item| accum + item)
Expand Down Expand Up @@ -858,25 +875,25 @@ impl<II: InputIndicator> ConstraintCircuitBuilder<II> {
}
}

/// Create constant leaf node
/// Create constant leaf node.
pub fn x_constant(&self, xfe: XFieldElement) -> ConstraintCircuitMonad<II> {
let expression = XConstant(xfe);
self.make_leaf(expression)
}

/// Create constant leaf node
/// Create constant leaf node.
pub fn b_constant(&self, bfe: BFieldElement) -> ConstraintCircuitMonad<II> {
let expression = BConstant(bfe);
self.make_leaf(expression)
}

/// Create deterministic input leaf node
/// Create deterministic input leaf node.
pub fn input(&self, input: II) -> ConstraintCircuitMonad<II> {
let expression = Input(input);
self.make_leaf(expression)
}

/// Create challenge leaf node
/// Create challenge leaf node.
pub fn challenge(&self, challenge_id: ChallengeId) -> ConstraintCircuitMonad<II> {
let expression = Challenge(challenge_id);
self.make_leaf(expression)
Expand Down

0 comments on commit 9386878

Please sign in to comment.