Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stage 0 public reference #2556

Merged
merged 18 commits into from
Mar 24, 2025
2 changes: 1 addition & 1 deletion backend-utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ pub fn referenced_namespaces_algebraic_expression<F: FieldElement>(
.all_children()
.filter_map(|expr| match expr {
AlgebraicExpression::Reference(reference) => Some(extract_namespace(&reference.name)),
AlgebraicExpression::PublicReference(_) => unimplemented!(),
AlgebraicExpression::PublicReference(name) => Some(extract_namespace(name)),
AlgebraicExpression::Challenge(_)
| AlgebraicExpression::Number(_)
| AlgebraicExpression::BinaryOperation(_)
Expand Down
8 changes: 7 additions & 1 deletion backend/src/composite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ impl<F: FieldElement> Backend<F> for CompositeBackend<F> {
fn prove(
&self,
witness: &[(String, Vec<F>)],
publics: &BTreeMap<String, Option<F>>,
prev_proof: Option<Proof>,
witgen_callback: WitgenCallback<F>,
) -> Result<Proof, Error> {
Expand Down Expand Up @@ -339,7 +340,12 @@ impl<F: FieldElement> Backend<F> for CompositeBackend<F> {
.expect("Machine does not support the given size");

let status = time_stage(machine, size, 0, || {
sub_prover::run(scope, &inner_machine_data.backend, witness)
sub_prover::run(
scope,
&inner_machine_data.backend,
witness,
publics.clone(),
)
});

Some((status, machine_entry, size))
Expand Down
6 changes: 5 additions & 1 deletion backend/src/composite/sub_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub fn run<'s, 'env, F: FieldElement>(
scope: &'s Scope<'s, 'env>,
prover: &'env Mutex<Box<dyn Backend<F>>>,
witness: Vec<(String, Vec<F>)>,
publics: BTreeMap<String, Option<F>>,
) -> RunStatus<'s, F>
where
{
Expand Down Expand Up @@ -44,7 +45,10 @@ where
// proof, even if it's not needed anymore. We should probably change
// this API so the Vec is moved into the prover, and returned in the
// callback and result.
prover.lock().unwrap().prove(&witness, None, callback)
prover
.lock()
.unwrap()
.prove(&witness, &publics, None, callback)
});

SubProver {
Expand Down
2 changes: 2 additions & 0 deletions backend/src/estark/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod polygon_wrapper;
pub mod starky_wrapper;

use std::{
collections::BTreeMap,
fs::File,
io::{self, BufWriter, Write},
iter::{once, repeat},
Expand Down Expand Up @@ -261,6 +262,7 @@ impl<F: FieldElement> Backend<F> for DumpBackend<F> {
fn prove(
&self,
witness: &[(String, Vec<F>)],
_publics: &BTreeMap<String, Option<F>>,
prev_proof: Option<Proof>,
// TODO: Implement challenges
_witgen_callback: WitgenCallback<F>,
Expand Down
3 changes: 2 additions & 1 deletion backend/src/estark/polygon_wrapper.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{fs, path::PathBuf, sync::Arc};
use std::{collections::BTreeMap, fs, path::PathBuf, sync::Arc};

use powdr_ast::analyzed::Analyzed;
use powdr_executor::{
Expand Down Expand Up @@ -49,6 +49,7 @@ impl<F: FieldElement> Backend<F> for PolygonBackend<F> {
fn prove(
&self,
witness: &[(String, Vec<F>)],
_publics: &BTreeMap<String, Option<F>>,
prev_proof: Option<Proof>,
// TODO: Implement challenges
_witgen_callback: WitgenCallback<F>,
Expand Down
2 changes: 2 additions & 0 deletions backend/src/estark/starky_wrapper.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::BTreeMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
Expand Down Expand Up @@ -192,6 +193,7 @@ impl Backend<GoldilocksField> for EStark {
fn prove(
&self,
witness: &[(String, Vec<GoldilocksField>)],
_publics: &BTreeMap<String, Option<GoldilocksField>>,
prev_proof: Option<crate::Proof>,
// TODO: Implement challenges
_witgen_callback: WitgenCallback<GoldilocksField>,
Expand Down
3 changes: 3 additions & 0 deletions backend/src/halo2/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::BTreeMap;
use std::io;
use std::path::PathBuf;
use std::sync::Arc;
Expand Down Expand Up @@ -144,6 +145,7 @@ impl Backend<Bn254Field> for Halo2Prover {
fn prove(
&self,
witness: &[(String, Vec<Bn254Field>)],
_publics: &BTreeMap<String, Option<Bn254Field>>,
prev_proof: Option<Proof>,
witgen_callback: WitgenCallback<Bn254Field>,
) -> Result<Proof, Error> {
Expand Down Expand Up @@ -231,6 +233,7 @@ impl<T: FieldElement> Backend<T> for Halo2Mock<T> {
fn prove(
&self,
witness: &[(String, Vec<T>)],
_publics: &BTreeMap<String, Option<T>>,
prev_proof: Option<Proof>,
witgen_callback: WitgenCallback<T>,
) -> Result<Proof, Error> {
Expand Down
3 changes: 2 additions & 1 deletion backend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ mod mock;
use powdr_ast::analyzed::Analyzed;
use powdr_executor::{constant_evaluator::VariablySizedColumn, witgen::WitgenCallback};
use powdr_number::{DegreeType, FieldElement};
use std::{io, path::PathBuf, sync::Arc};
use std::{collections::BTreeMap, io, path::PathBuf, sync::Arc};
use strum::{Display, EnumString, EnumVariantNames};

#[derive(Clone, EnumString, EnumVariantNames, Display, Copy)]
Expand Down Expand Up @@ -197,6 +197,7 @@ pub trait Backend<F: FieldElement>: Send {
fn prove(
&self,
witness: &[(String, Vec<F>)],
publics: &BTreeMap<String, Option<F>>,
prev_proof: Option<Proof>,
witgen_callback: WitgenCallback<F>,
) -> Result<Proof, Error>;
Expand Down
1 change: 1 addition & 0 deletions backend/src/mock/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ impl<F: FieldElement> Backend<F> for MockBackend<F> {
fn prove(
&self,
witness: &[(String, Vec<F>)],
_publics: &BTreeMap<String, Option<F>>,
prev_proof: Option<Proof>,
witgen_callback: WitgenCallback<F>,
) -> Result<Proof, Error> {
Expand Down
5 changes: 3 additions & 2 deletions backend/src/plonky3/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
mod stark;

use std::{io, path::PathBuf, sync::Arc};
use std::{collections::BTreeMap, io, path::PathBuf, sync::Arc};

use powdr_ast::analyzed::Analyzed;
use powdr_executor::{constant_evaluator::VariablySizedColumn, witgen::WitgenCallback};
Expand Down Expand Up @@ -71,14 +71,15 @@ where
fn prove(
&self,
witness: &[(String, Vec<T>)],
publics: &BTreeMap<String, Option<T>>,
prev_proof: Option<Proof>,
witgen_callback: WitgenCallback<T>,
) -> Result<Proof, Error> {
if prev_proof.is_some() {
return Err(Error::NoAggregationAvailable);
}

Ok(self.prove(witness, witgen_callback)?)
Ok(self.prove(witness, publics, witgen_callback)?)
}

fn export_verification_key(&self, output: &mut dyn io::Write) -> Result<(), Error> {
Expand Down
1 change: 1 addition & 0 deletions backend/src/plonky3/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ where
pub fn prove(
&self,
witness: &[(String, Vec<T>)],
_publics: &BTreeMap<String, Option<T>>,
witgen_callback: WitgenCallback<T>,
) -> Result<Vec<u8>, String> {
let mut witness_by_machine = self
Expand Down
2 changes: 2 additions & 0 deletions backend/src/stwo/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::collections::BTreeMap;
use std::io;
use std::path::PathBuf;
use std::sync::Arc;
Expand Down Expand Up @@ -79,6 +80,7 @@ where
fn prove(
&self,
witness: &[(String, Vec<M31>)],
_publics: &BTreeMap<String, Option<M31>>,
prev_proof: Option<Proof>,
witgen_callback: WitgenCallback<M31>,
) -> Result<Proof, Error> {
Expand Down
48 changes: 30 additions & 18 deletions executor/src/witgen/data_structures/mutable_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ impl<'a, T: FieldElement, Q: QueryCallback<T>> MutableState<'a, T, Q> {

/// Runs the first machine (unless there are no machines) end returns the generated columns.
/// The first machine might call other machines, which is handled automatically.
pub fn run(self) -> HashMap<String, Vec<T>> {
pub fn run(self) -> (HashMap<String, Vec<T>>, BTreeMap<String, T>) {
if let Some(first_machine) = self.machines.first() {
first_machine.try_borrow_mut().unwrap().run_timed(&self);
}
self.take_witness_col_values()
self.take_witness_col_and_public_values()
}

pub fn can_process_call_fully(
Expand Down Expand Up @@ -96,25 +96,37 @@ impl<'a, T: FieldElement, Q: QueryCallback<T>> MutableState<'a, T, Q> {
}

/// Extracts the witness column values from the machines.
fn take_witness_col_values(self) -> HashMap<String, Vec<T>> {
// We keep the already processed machines mutably borrowed so that
// "later" machines do not try to create new rows in already processed
// machines.
let mut processed = vec![];
self.machines
fn take_witness_col_and_public_values(self) -> (HashMap<String, Vec<T>>, BTreeMap<String, T>) {
let witness_columns = {
// We keep the already processed machines mutably borrowed so that
// "later" machines do not try to create new rows in already processed
// machines.
let mut processed = vec![];
self.machines
.iter()
.flat_map(|machine| {
let mut machine = machine.try_borrow_mut().unwrap_or_else(|_| {
panic!("Recursive machine dependencies while finishing machines.");
});
let columns = machine.take_witness_col_values(&self).into_iter();
processed.push(machine);
columns
})
.collect()
};

let public_values = self
.machines
.iter()
.flat_map(|machine| {
let mut machine = machine
.try_borrow_mut()
.map_err(|_| {
panic!("Recursive machine dependencies while finishing machines.");
})
.unwrap();
let columns = machine.take_witness_col_values(&self).into_iter();
processed.push(machine);
columns
let mut machine = machine.try_borrow_mut().unwrap_or_else(|_| {
panic!("Recursive machine dependencies while finishing machines.");
});
machine.take_public_values().into_iter()
})
.collect()
.collect();

(witness_columns, public_values)
}

pub fn query_callback(&self) -> &Q {
Expand Down
7 changes: 7 additions & 0 deletions executor/src/witgen/machines/block_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,13 @@ impl<'a, T: FieldElement> Machine<'a, T> for BlockMachine<'a, T> {
.map(|(id, values)| (self.fixed_data.column_name(&id).to_string(), values))
.collect()
}

fn take_public_values(&mut self) -> BTreeMap<String, T> {
std::mem::take(&mut self.publics)
.into_iter()
.map(|(key, value)| (key.to_string(), value))
.collect()
}
}

impl<'a, T: FieldElement> BlockMachine<'a, T> {
Expand Down
14 changes: 11 additions & 3 deletions executor/src/witgen/machines/dynamic_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,11 @@ impl<'a, T: FieldElement> Machine<'a, T> for DynamicMachine<'a, T> {
fn run<Q: QueryCallback<T>>(&mut self, mutable_state: &MutableState<'a, T, Q>) {
assert!(self.data.is_empty());
let first_row = self.compute_partial_first_row(mutable_state);
self.data = self
let process_result = self
.process(first_row, 0, mutable_state, None, true)
.updated_data
.block;
.updated_data;
self.data = process_result.block;
self.publics.extend(process_result.publics);
}

fn process_plookup<'b, Q: QueryCallback<T>>(
Expand Down Expand Up @@ -124,6 +125,13 @@ impl<'a, T: FieldElement> Machine<'a, T> for DynamicMachine<'a, T> {
.map(|(id, values)| (self.fixed_data.column_name(&id).to_string(), values))
.collect()
}

fn take_public_values(&mut self) -> BTreeMap<String, T> {
std::mem::take(&mut self.publics)
.into_iter()
.map(|(key, value)| (key.to_string(), value))
.collect()
}
}

impl<'a, T: FieldElement> DynamicMachine<'a, T> {
Expand Down
8 changes: 8 additions & 0 deletions executor/src/witgen/machines/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ pub trait Machine<'a, T: FieldElement>: Send + Sync {

/// Returns the identity IDs of the connecting identities that this machine is responsible for.
fn bus_ids(&self) -> Vec<T>;

fn take_public_values(&mut self) -> BTreeMap<String, T> {
BTreeMap::new()
}
}

#[repr(C)]
Expand Down Expand Up @@ -239,6 +243,10 @@ impl<'a, T: FieldElement> Machine<'a, T> for KnownMachine<'a, T> {
match_variant!(self, m => m.take_witness_col_values(mutable_state))
}

fn take_public_values(&mut self) -> BTreeMap<String, T> {
match_variant!(self, m => m.take_public_values())
}

fn bus_ids(&self) -> Vec<T> {
match_variant!(self, m => m.bus_ids())
}
Expand Down
14 changes: 10 additions & 4 deletions executor/src/witgen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ mod vm_processor;
pub use affine_expression::{AffineExpression, AffineResult, AlgebraicVariable};
pub use evaluators::partial_expression_evaluator::{PartialExpressionEvaluator, SymbolicVariables};

pub type Witness<T> = Vec<(String, Vec<T>)>;
pub type Publics<T> = BTreeMap<String, Option<T>>;
Comment on lines +53 to +54
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like it would make sense to introduce a struct for both? Whenever you need the witness, you'll also need the publics. So then the pipeline would have a field witness_and_publics: Option<Arc<WitnessAndPublics<T>>> instead of two fields.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sgtm

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Took a slightly different approach by always returning (witness, publics) in the main APIs. Kept them separate so that in helper functions we can return only witness or only publics.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pipeline now has witness_and_publics as a single field.


static OUTER_CODE_NAME: &str = "witgen (outer code)";

// TODO change this so that it has functions
Expand Down Expand Up @@ -141,6 +144,7 @@ impl<T: FieldElement> WitgenCallbackContext<T> {
.with_external_witness_values(current_witness)
.with_challenges(stage, challenges)
.generate()
.0
}
}
}
Expand Down Expand Up @@ -202,7 +206,7 @@ impl<'a, 'b, T: FieldElement> WitnessGenerator<'a, 'b, T> {

/// Generates the committed polynomial values
/// @returns the values (in source order) and the degree of the polynomials.
pub fn generate(self) -> Vec<(String, Vec<T>)> {
pub fn generate(self) -> (Witness<T>, Publics<T>) {
record_start(OUTER_CODE_NAME);
let fixed = FixedData::new(
self.analyzed,
Expand Down Expand Up @@ -241,7 +245,8 @@ impl<'a, 'b, T: FieldElement> WitnessGenerator<'a, 'b, T> {
let machines = MachineExtractor::new(&fixed).split_out_machines();

// Run main machine and extract columns from all machines.
let columns = MutableState::new(machines.into_iter(), &self.query_callback).run();
let (columns, _publics) =
MutableState::new(machines.into_iter(), &self.query_callback).run();
Comment on lines 247 to +249
Copy link
Collaborator Author

@qwang98 qwang98 Mar 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that _publics is not passed to the backend yet till I also implement stage-1 publics. However, the pipeline for passing publics to the backend is ready (rest of the PR).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, this confuses me still: run() returns some publics, but then they are ignored and instead we use extract_publics(&columns, self.analyzed).

Is it because at this point we'd expect them to be the same?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually intentional. This PR technically still uses witness columns to derive public values as you identified in extract_publics, because using witgen public reference values requires adding public references to all existing examples that have public declarations, which is quite involved and requires lots of debugging, so I deferred it to another PR downstream: #2572.

In this other PR, extract_publics is updated to use public reference values only: https://github.com/powdr-labs/powdr/pull/2572/files#diff-06c2d788cbdf696f356c8bfb53645c75075218b7b65aa5456a76b33d9d122e87L294-R307

This PR focuses more on making the pipeline work.


let publics = extract_publics(&columns, self.analyzed);
if !publics.is_empty() {
Expand All @@ -258,7 +263,7 @@ impl<'a, 'b, T: FieldElement> WitnessGenerator<'a, 'b, T> {

let mut columns = if self.stage == 0 {
// Multiplicities should be computed in the first stage
MultiplicityColumnGenerator::new(&fixed).generate(columns, publics)
MultiplicityColumnGenerator::new(&fixed).generate(columns, publics.clone())
} else {
columns
};
Expand All @@ -279,7 +284,8 @@ impl<'a, 'b, T: FieldElement> WitnessGenerator<'a, 'b, T> {
(name, column)
})
.collect::<Vec<_>>();
witness_cols

(witness_cols, publics)
}
}

Expand Down
Loading
Loading