Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pilopt two pass pipeline #2573

Closed
wants to merge 16 commits into from
6 changes: 6 additions & 0 deletions backend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ pub trait BackendFactory<F: FieldElement> {
fn generate_setup(&self, _size: DegreeType, _output: &mut dyn io::Write) -> Result<(), Error> {
Err(Error::NoSetupAvailable)
}

fn specialize_pil(&self, pil: Analyzed<F>) -> Analyzed<F> {
// TODO: currently defaults to the identity function
// Move `bus_multi_linker` calls here in the future
pil
}
}

/// Dynamic interface for a backend.
Expand Down
25 changes: 9 additions & 16 deletions backend/src/plonky3/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,8 @@ where
#[cfg(test)]
mod tests {

use super::Plonky3Prover;
use powdr_number::{BabyBearField, GoldilocksField, Mersenne31Field};
use powdr_pipeline::Pipeline;
use powdr_pipeline::{BackendType, Pipeline};
use test_log::test;

use powdr_plonky3::{Commitment, FieldElementMap, ProverData};
Expand All @@ -349,26 +348,20 @@ mod tests {
ProverData<F>: Send + serde::Serialize + for<'a> serde::Deserialize<'a>,
Commitment<F>: Send,
{
let mut pipeline = Pipeline::<F>::default().from_pil_string(pil.to_string());
let pil = pipeline.compute_optimized_pil().unwrap();
let witness_callback = pipeline.witgen_callback().unwrap();
let witness = &mut pipeline.compute_witness().unwrap();
let fixed = pipeline.compute_fixed_cols().unwrap();
let mut pipeline = Pipeline::<F>::default()
.with_backend(BackendType::Plonky3, None)
.from_pil_string(pil.to_string());

let mut prover = Plonky3Prover::new(pil, fixed);
prover.setup();
let proof = prover.prove(witness, witness_callback);

assert!(proof.is_ok());
let proof = pipeline.compute_proof().unwrap().clone();

if let Some(publics) = malicious_publics {
prover
pipeline
.verify(
&proof.unwrap(),
&publics
&proof,
&[publics
.iter()
.map(|i| F::from(*i as u64))
.collect::<Vec<_>>(),
.collect::<Vec<_>>()],
)
.unwrap()
}
Expand Down
12 changes: 11 additions & 1 deletion cli-rs/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use env_logger::fmt::Color;
use env_logger::{Builder, Target};
use log::LevelFilter;

use powdr::backend::BackendType;
use powdr::number::{
BabyBearField, BigUint, Bn254Field, FieldElement, GoldilocksField, KnownField, KoalaBearField,
};
Expand Down Expand Up @@ -161,6 +162,11 @@ enum Commands {
#[arg(value_parser = clap_enum_variants!(FieldArgument))]
field: FieldArgument,

/// The backend to run witgen for
#[arg(short, long)]
#[arg(value_parser = clap_enum_variants!(BackendType))]
backend: BackendType,

/// Comma-separated list of free inputs (numbers).
#[arg(short, long)]
#[arg(default_value_t = String::new())]
Expand Down Expand Up @@ -304,6 +310,7 @@ fn run_command(command: Commands) {
Commands::Witgen {
file,
field,
backend,
inputs,
output_directory,
continuations,
Expand All @@ -326,6 +333,7 @@ fn run_command(command: Commands) {
};
call_with_field!(execute::<field>(
Path::new(&file),
backend,
split_inputs(&inputs),
Path::new(&output_directory),
continuations,
Expand Down Expand Up @@ -409,6 +417,7 @@ fn execute_fast<F: FieldElement>(
#[allow(clippy::too_many_arguments)]
fn execute<F: FieldElement>(
file_name: &Path,
backend: BackendType,
inputs: Vec<F>,
output_dir: &Path,
continuations: bool,
Expand All @@ -417,6 +426,7 @@ fn execute<F: FieldElement>(
) -> Result<(), Vec<String>> {
let mut pipeline = Pipeline::<F>::default()
.from_asm_file(file_name.to_path_buf())
.with_backend(backend, None)
.with_prover_inputs(inputs)
.with_output(output_dir.into(), true);

Expand All @@ -432,7 +442,7 @@ fn execute<F: FieldElement>(
} else {
let fixed = pipeline.compute_fixed_cols().unwrap().clone();
let asm = pipeline.compute_analyzed_asm().unwrap().clone();
let pil = pipeline.compute_optimized_pil().unwrap();
let pil = pipeline.compute_backend_tuned_pil().unwrap().clone();

let start = Instant::now();

Expand Down
2 changes: 1 addition & 1 deletion cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ fn run<F: FieldElement>(
) -> Result<(), Vec<String>> {
pipeline = pipeline.with_setup_file(params.map(PathBuf::from));

pipeline.compute_witness().unwrap();
pipeline.compute_optimized_pil().unwrap();

if let Some(backend) = prove_with {
pipeline
Expand Down
9 changes: 6 additions & 3 deletions pipeline/benches/jit_witgen_benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use ::powdr_pipeline::Pipeline;
use powdr_backend::BackendType;
use powdr_number::GoldilocksField;

use criterion::{criterion_group, criterion_main, Criterion};
Expand All @@ -10,9 +11,11 @@ fn jit_witgen_benchmark(c: &mut Criterion) {
group.sample_size(10);

// Poseidon benchmark
let mut pipeline =
Pipeline::<T>::default().from_file("../test_data/std/poseidon_benchmark.asm".into());
pipeline.compute_optimized_pil().unwrap();
let mut pipeline = Pipeline::<T>::default()
.from_file("../test_data/std/poseidon_benchmark.asm".into())
.with_backend(BackendType::Mock, None);
// this `jit_witgen_benchmark` function will also require backend type
pipeline.compute_backend_tuned_pil().unwrap();
pipeline.compute_fixed_cols().unwrap();

group.bench_function("jit_witgen_benchmark", |b| {
Expand Down
62 changes: 50 additions & 12 deletions pipeline/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ pub struct Artifacts<T: FieldElement> {
/// An analyzed .pil file, with all dependencies imported, potentially from other files.
analyzed_pil: Option<Analyzed<T>>,
/// An optimized .pil file.
optimized_pil: Option<Arc<Analyzed<T>>>,
optimized_pil: Option<Analyzed<T>>,
/// A .pil file after backend-specific tuning and another optimization pass.
backend_tuned_pil: Option<Arc<Analyzed<T>>>,
/// Fully evaluated fixed columns.
fixed_cols: Option<Arc<VariablySizedColumns<T>>>,
/// Generated witnesses.
Expand Down Expand Up @@ -168,6 +170,7 @@ impl<T: FieldElement> Clone for Artifacts<T> {
pil_string: self.pil_string.clone(),
analyzed_pil: self.analyzed_pil.clone(),
optimized_pil: self.optimized_pil.clone(),
backend_tuned_pil: self.backend_tuned_pil.clone(),
fixed_cols: self.fixed_cols.clone(),
witness: self.witness.clone(),
proof: self.proof.clone(),
Expand Down Expand Up @@ -367,6 +370,14 @@ impl<T: FieldElement> Pipeline<T> {
self
}

pub fn with_backend_if_none(&mut self, backend: BackendType, options: Option<BackendOptions>) {
if self.arguments.backend.is_none() {
self.arguments.backend = Some(backend);
self.arguments.backend_options = options.unwrap_or_default();
self.artifact.backend = None;
}
}

pub fn with_setup_file(mut self, setup_file: Option<PathBuf>) -> Self {
self.arguments.setup_file = setup_file;
self.artifact.backend = None;
Expand Down Expand Up @@ -483,7 +494,7 @@ impl<T: FieldElement> Pipeline<T> {

Ok(Pipeline {
artifact: Artifacts {
optimized_pil: Some(Arc::new(analyzed)),
optimized_pil: Some(analyzed),
..Default::default()
},
name,
Expand Down Expand Up @@ -958,9 +969,9 @@ impl<T: FieldElement> Pipeline<T> {
Ok(self.artifact.analyzed_pil.as_ref().unwrap())
}

pub fn compute_optimized_pil(&mut self) -> Result<Arc<Analyzed<T>>, Vec<String>> {
pub fn compute_optimized_pil(&mut self) -> Result<&Analyzed<T>, Vec<String>> {
if let Some(ref optimized_pil) = self.artifact.optimized_pil {
return Ok(optimized_pil.clone());
return Ok(optimized_pil);
}

self.compute_analyzed_pil()?;
Expand All @@ -971,21 +982,48 @@ impl<T: FieldElement> Pipeline<T> {
self.maybe_write_pil(&optimized, "_opt")?;
self.maybe_write_pil_object(&optimized, "_opt")?;

self.artifact.optimized_pil = Some(Arc::new(optimized));
self.artifact.optimized_pil = Some(optimized);

Ok(self.artifact.optimized_pil.as_ref().unwrap())
}

pub fn optimized_pil(&self) -> Result<&Analyzed<T>, Vec<String>> {
Ok(self.artifact.optimized_pil.as_ref().unwrap())
}

pub fn compute_backend_tuned_pil(&mut self) -> Result<Arc<Analyzed<T>>, Vec<String>> {
if let Some(ref backend_tuned_pil) = self.artifact.backend_tuned_pil {
return Ok(backend_tuned_pil.clone());
}

self.compute_optimized_pil()?;

let backend_type = self.arguments.backend.expect("no backend selected!");

// If backend option is set, compute and cache the backend-tuned pil in artifacts and return backend-tuned pil.
let optimized_pil = self.artifact.optimized_pil.clone().unwrap();
let factory = backend_type.factory::<T>();
self.log("Apply backend-specific tuning to optimized pil...");
let backend_tuned_pil = factory.specialize_pil(optimized_pil);
self.log("Optimizing pil (post backend-specific tuning)...");
let reoptimized_pil = powdr_pilopt::optimize(backend_tuned_pil);
self.maybe_write_pil(&reoptimized_pil, "_backend_tuned")?;
self.maybe_write_pil_object(&reoptimized_pil, "_backend_tuned")?;
self.artifact.backend_tuned_pil = Some(Arc::new(reoptimized_pil));

Ok(self.artifact.optimized_pil.as_ref().unwrap().clone())
Ok(self.artifact.backend_tuned_pil.as_ref().unwrap().clone())
}

pub fn optimized_pil(&self) -> Result<Arc<Analyzed<T>>, Vec<String>> {
Ok(self.artifact.optimized_pil.as_ref().unwrap().clone())
pub fn backend_tuned_pil(&self) -> Result<Arc<Analyzed<T>>, Vec<String>> {
Ok(self.artifact.backend_tuned_pil.as_ref().unwrap().clone())
}

pub fn compute_fixed_cols(&mut self) -> Result<Arc<VariablySizedColumns<T>>, Vec<String>> {
if let Some(ref fixed_cols) = self.artifact.fixed_cols {
return Ok(fixed_cols.clone());
}

let pil = self.compute_optimized_pil()?;
let pil = self.compute_backend_tuned_pil()?; // will panic if backend type is not set yet

self.log("Evaluating fixed columns...");
let start = Instant::now();
Expand All @@ -1012,7 +1050,7 @@ impl<T: FieldElement> Pipeline<T> {

self.host_context.clear();

let pil = self.compute_optimized_pil()?;
let pil = self.compute_backend_tuned_pil()?; // will panic if backend type is not set yet
let fixed_cols = self.compute_fixed_cols()?;

assert_eq!(pil.constant_count(), fixed_cols.len());
Expand Down Expand Up @@ -1072,7 +1110,7 @@ impl<T: FieldElement> Pipeline<T> {
}

pub fn publics(&self) -> Result<Vec<(String, Option<T>)>, Vec<String>> {
let pil = self.optimized_pil()?;
let pil = self.backend_tuned_pil()?; // will panic if backend type is not set yet
let witness = self.witness()?;
Ok(extract_publics(witness.iter().map(|(k, v)| (k, v)), &pil)
.into_iter()
Expand All @@ -1099,7 +1137,7 @@ impl<T: FieldElement> Pipeline<T> {
if self.artifact.backend.is_some() {
return Ok(self.artifact.backend.as_deref_mut().unwrap());
}
let pil = self.compute_optimized_pil()?;
let pil = self.compute_backend_tuned_pil()?; // will panic if backend type is not set yet
let fixed_cols = self.compute_fixed_cols()?;

let backend = self.arguments.backend.expect("no backend selected!");
Expand Down
Loading
Loading