Skip to content

Commit

Permalink
feat: parallel recursion tracegen (#1095)
Browse files Browse the repository at this point in the history
  • Loading branch information
ctian1 authored Jul 12, 2024
1 parent 7b74b8b commit f11e51a
Show file tree
Hide file tree
Showing 11 changed files with 283 additions and 306 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions core/src/runtime/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,6 @@ impl ExecutionRecord {
pub fn split(&mut self, last: bool, opts: SplitOpts) -> Vec<ExecutionRecord> {
let mut shards = Vec::new();

println!("keccak split {}", opts.keccak_split_threshold);

macro_rules! split_events {
($self:ident, $events:ident, $shards:ident, $threshold:expr, $exact:expr) => {
let events = std::mem::take(&mut $self.$events);
Expand Down
48 changes: 38 additions & 10 deletions core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub use programs::*;

use crate::{memory::MemoryCols, operations::field::params::Limbs};
use generic_array::ArrayLength;
use p3_maybe_rayon::prelude::{ParallelBridge, ParallelIterator};

pub const fn indices_arr<const N: usize>() -> [usize; N] {
let mut indices_arr = [0; N];
Expand Down Expand Up @@ -88,30 +89,36 @@ pub fn pad_rows_fixed<R: Clone>(
) {
let nb_rows = rows.len();
let dummy_row = row_fn();
match size_log2 {
Some(size_log2) => {
let padded_nb_rows = 1 << size_log2;
if nb_rows * 2 < padded_nb_rows {
rows.resize(next_power_of_two(nb_rows, size_log2), dummy_row);
}

/// Returns the next power of two that is >= `n` and >= 16. If `fixed_power` is set, it will return
/// `2^fixed_power` after checking that `n <= 2^fixed_power`.
pub fn next_power_of_two(n: usize, fixed_power: Option<usize>) -> usize {
match fixed_power {
Some(power) => {
let padded_nb_rows = 1 << power;
if n * 2 < padded_nb_rows {
tracing::warn!(
"fixed log2 rows can be potentially reduced: got {}, expected {}",
nb_rows,
n,
padded_nb_rows
);
}
if nb_rows > padded_nb_rows {
if n > padded_nb_rows {
panic!(
"fixed log2 rows is too small: got {}, expected {}",
nb_rows, padded_nb_rows
n, padded_nb_rows
);
}
rows.resize(padded_nb_rows, dummy_row);
padded_nb_rows
}
None => {
let mut padded_nb_rows = nb_rows.next_power_of_two();
let mut padded_nb_rows = n.next_power_of_two();
if padded_nb_rows < 16 {
padded_nb_rows = 16;
}
rows.resize(padded_nb_rows, dummy_row);
padded_nb_rows
}
}
}
Expand Down Expand Up @@ -186,3 +193,24 @@ pub fn log2_strict_usize(n: usize) -> usize {
assert_eq!(n.wrapping_shr(res), 1, "Not a power of two: {n}");
res as usize
}

pub fn par_for_each_row<P, F>(vec: &mut [F], num_cols: usize, processor: P)
where
F: Send,
P: Fn(usize, &mut [F]) + Send + Sync,
{
// Split the vector into `num_cpus` chunks, but at least `num_cpus` rows per chunk.
let len = vec.len();
let cpus = num_cpus::get();
let ceil_div = (len + cpus - 1) / cpus;
let chunk_size = std::cmp::max(ceil_div, cpus);

vec.chunks_mut(chunk_size * num_cols)
.enumerate()
.par_bridge()
.for_each(|(i, chunk)| {
chunk.chunks_mut(num_cols).enumerate().for_each(|(j, row)| {
processor(i * chunk_size + j, row);
});
});
}
8 changes: 2 additions & 6 deletions prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ use sp1_recursion_program::machine::{
pub use sp1_recursion_program::machine::{
SP1DeferredMemoryLayout, SP1RecursionMemoryLayout, SP1ReduceMemoryLayout, SP1RootMemoryLayout,
};
use tracing::instrument;
use tracing::{info_span, instrument};

Check warning on line 62 in prover/src/lib.rs

View workflow job for this annotation

GitHub Actions / Test

unused import: `info_span`

Check warning on line 62 in prover/src/lib.rs

View workflow job for this annotation

GitHub Actions / Plonk Docker

unused import: `info_span`

Check warning on line 62 in prover/src/lib.rs

View workflow job for this annotation

GitHub Actions / Plonk Native

unused import: `info_span`

Check warning on line 62 in prover/src/lib.rs

View workflow job for this annotation

GitHub Actions / Test (x86-64)

unused import: `info_span`

Check warning on line 62 in prover/src/lib.rs

View workflow job for this annotation

GitHub Actions / Test (x86-64)

unused import: `info_span`

Check failure on line 62 in prover/src/lib.rs

View workflow job for this annotation

GitHub Actions / Formatting & Clippy

unused import: `info_span`

Check warning on line 62 in prover/src/lib.rs

View workflow job for this annotation

GitHub Actions / Test (ARM)

unused import: `info_span`

Check warning on line 62 in prover/src/lib.rs

View workflow job for this annotation

GitHub Actions / Test (ARM)

unused import: `info_span`
pub use types::*;
use utils::words_to_bytes;

Expand Down Expand Up @@ -295,10 +295,6 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
for batch in shard_proofs.chunks(batch_size) {
let proofs = batch.to_vec();

let public_values: &PublicValues<Word<BabyBear>, BabyBear> =
proofs.last().unwrap().public_values.as_slice().borrow();
println!("core execution shard: {}", public_values.execution_shard);

core_inputs.push(SP1RecursionMemoryLayout {
vk,
machine: self.core_prover.machine(),
Expand Down Expand Up @@ -517,6 +513,7 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
})
}

/// Generate a proof with the compress machine.
pub fn compress_machine_proof(
&self,
input: impl Hintable<InnerConfig>,
Expand All @@ -533,7 +530,6 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
witness_stream.extend(input.write());

runtime.witness_stream = witness_stream.into();

runtime
.run()
.map_err(|e| SP1RecursionProverError::RuntimeError(e.to_string()))?;
Expand Down
1 change: 1 addition & 0 deletions recursion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ serde_with = "3.8.3"
backtrace = { version = "0.3.71", features = ["serde"] }
arrayref = "0.3.7"
static_assertions = "1.1.0"
num_cpus = "1.16.0"

[dev-dependencies]
rand = "0.8.5"
10 changes: 1 addition & 9 deletions recursion/core/src/cpu/columns/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use std::mem::{size_of, transmute};
use std::mem::size_of;

use crate::memory::{MemoryReadCols, MemoryReadWriteCols};
use p3_air::BaseAir;
use sp1_core::utils::indices_arr;
use sp1_derive::AlignedBorrow;

mod branch;
Expand All @@ -23,13 +22,6 @@ use super::CpuChip;

pub const NUM_CPU_COLS: usize = size_of::<CpuCols<u8>>();

const fn make_col_map() -> CpuCols<usize> {
let indices_arr = indices_arr::<NUM_CPU_COLS>();
unsafe { transmute::<[usize; NUM_CPU_COLS], CpuCols<usize>>(indices_arr) }
}

pub(crate) const CPU_COL_MAP: CpuCols<usize> = make_col_map();

impl<F: Send + Sync, const L: usize> BaseAir<F> for CpuChip<F, L> {
fn width(&self) -> usize {
NUM_CPU_COLS
Expand Down
Loading

0 comments on commit f11e51a

Please sign in to comment.