Skip to content

Commit

Permalink
prover gas estimate in sp1-prover
Browse files Browse the repository at this point in the history
  • Loading branch information
tqn committed Jan 25, 2025
1 parent 73ab98a commit 53ff6a2
Show file tree
Hide file tree
Showing 13 changed files with 295 additions and 46 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion crates/core/executor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ sp1-zkvm = { workspace = true, features = ["lib"] }
test-artifacts = { workspace = true }

[features]
default = ["gas"] # REMOVE ME BEFORE MERGING
bigint-rug = ["sp1-curves/bigint-rug"]
profiling = [
"dep:goblin",
Expand Down
45 changes: 38 additions & 7 deletions crates/core/executor/src/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,46 @@ impl RiscvAirId {
]
}

// Whether the trace generation for this AIR is deferred.
pub(crate) fn is_deferred(self) -> bool {
/// TODO replace these three with subenums or something
/// Whether the ID represents a core AIR.
#[must_use]
pub fn is_core(self) -> bool {
matches!(
self,
RiscvAirId::Cpu
| RiscvAirId::AddSub
| RiscvAirId::Mul
| RiscvAirId::Bitwise
| RiscvAirId::ShiftLeft
| RiscvAirId::ShiftRight
| RiscvAirId::DivRem
| RiscvAirId::Lt
| RiscvAirId::Auipc
| RiscvAirId::MemoryLocal
| RiscvAirId::MemoryInstrs
| RiscvAirId::Branch
| RiscvAirId::Jump
| RiscvAirId::SyscallCore
| RiscvAirId::SyscallInstrs
| RiscvAirId::Global,
)
}

/// Whether the ID represents a memory AIR.
#[must_use]
pub fn is_memory(self) -> bool {
matches!(
self,
RiscvAirId::MemoryGlobalInit | RiscvAirId::MemoryGlobalFinalize | RiscvAirId::Global
)
}

/// Whether the ID represents a precompile AIR.
#[must_use]
pub fn is_precompile(self) -> bool {
matches!(
self,
// Global memory.
RiscvAirId::MemoryGlobalInit
| RiscvAirId::MemoryGlobalFinalize
// Precompiles.
| RiscvAirId::ShaExtend
RiscvAirId::ShaExtend
| RiscvAirId::ShaCompress
| RiscvAirId::EdAddAssign
| RiscvAirId::EdDecompress
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
//! Data that may be collected during execution and used to estimate trace area.
use std::ops::AddAssign;

use enum_map::EnumMap;
use hashbrown::HashMap;
use sp1_stark::SP1CoreOpts;
Expand All @@ -6,21 +10,28 @@ use crate::RiscvAirId;

const BYTE_NUM_ROWS: u64 = 1 << 16;

#[derive(Default, Clone)]
/// Data accumulated during execution to estimate the core trace area used to prove the execution.
#[derive(Clone, Debug, Default)]
pub struct TraceAreaEstimator {
pub core_area: u64,
/// Core shards, represented by the number of events per AIR.
pub core_shards: Vec<EnumMap<RiscvAirId, u64>>,
/// Deferred events, which are used to calculate trace area after execution has finished.
pub deferred_events: EnumMap<RiscvAirId, u64>,
}

impl TraceAreaEstimator {
/// An estimate of the total trace area required for the core proving stage.
/// This provides a prover gas metric.
#[must_use]
#[deprecated]
pub fn total_trace_area(
&self,
program_len: usize,
costs: &HashMap<RiscvAirId, u64>,
opts: &SP1CoreOpts,
) -> u64 {
let core_area = 0u64;

let deferred_area = self
.deferred_events
.iter()
Expand Down Expand Up @@ -52,21 +63,18 @@ impl TraceAreaEstimator {
// // Compute the program chip contribution.
let program_area = program_len as u64 * costs[&RiscvAirId::Program];

self.core_area + deferred_area + byte_area + program_area
core_area + deferred_area + byte_area + program_area
}
}

/// Mark the end of a shard. Estimates the area of core AIRs and defers appropriate counts.
pub(crate) fn flush_shard(
&mut self,
event_counts: &EnumMap<RiscvAirId, u64>,
costs: &HashMap<RiscvAirId, u64>,
) {
for (id, count) in event_counts {
if id.is_deferred() {
self.deferred_events[id] += count;
} else {
self.core_area += costs[&id] * count.next_power_of_two();
}
}
impl AddAssign for TraceAreaEstimator {
fn add_assign(&mut self, rhs: Self) {
let TraceAreaEstimator { core_shards, deferred_events } = self;
core_shards.extend(rhs.core_shards);
deferred_events
.as_mut_array()
.iter_mut()
.zip(rhs.deferred_events.as_array())
.for_each(|(l, r)| *l += r);
}
}
16 changes: 4 additions & 12 deletions crates/core/executor/src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{fs::File, io::BufWriter};
use std::{str::FromStr, sync::Arc};

#[cfg(feature = "gas")]
use crate::gas::TraceAreaEstimator;
use crate::estimator::TraceAreaEstimator;
#[cfg(feature = "profiling")]
use crate::profiler::Profiler;
use clap::ValueEnum;
Expand Down Expand Up @@ -383,16 +383,6 @@ impl<'a> Executor<'a> {
HookEnv { runtime: self }
}

/// An estimate of the total trace area required for the core proving stage.
/// This provides a prover gas metric.
#[cfg(feature = "gas")]
#[must_use]
pub fn total_trace_area(&self) -> Option<u64> {
self.trace_area_estimator.as_ref().map(|estimator| {
estimator.total_trace_area(self.program.instructions.len(), &self.costs, &self.opts)
})
}

/// Recover runtime state from a program and existing execution state.
#[must_use]
pub fn recover(program: Program, state: ExecutionState, opts: SP1CoreOpts) -> Self {
Expand Down Expand Up @@ -1780,13 +1770,15 @@ impl<'a> Executor<'a> {

/// Bump the record.
pub fn bump_record(&mut self) {
#[cfg(feature = "gas")]
if let Some(estimator) = &mut self.trace_area_estimator {
Self::estimate_riscv_event_counts(
&mut self.event_counts,
(self.state.clk >> 2) as u64,
&self.local_counts,
);
estimator.flush_shard(&self.event_counts, &self.costs);
// The above method estimates event counts only for core shards.
estimator.core_shards.push(self.event_counts);
}
self.local_counts = LocalCounts::default();
// Copy all of the existing local memory accesses to the record's local_memory_access vec.
Expand Down
4 changes: 2 additions & 2 deletions crates/core/executor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ mod context;
mod cost;
mod dependencies;
mod disassembler;
#[cfg(feature = "gas")]
pub mod estimator;
pub mod events;
mod executor;
#[cfg(feature = "gas")]
mod gas;
mod hook;
mod instruction;
mod io;
Expand Down
7 changes: 5 additions & 2 deletions crates/core/machine/src/riscv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,10 @@ impl<F: PrimeField32> RiscvAir<F> {
]
}

pub(crate) fn precompile_airs_with_memory_events_per_row() -> Vec<(Self, usize)> {
/// Internal function.
///
/// Returns the number of memory events per row of each precompile. Used in estimating trace area.
pub fn precompile_airs_with_memory_events_per_row() -> HashMap<RiscvAirId, usize> {
let mut airs: HashSet<_> = Self::get_airs_and_costs().0.into_iter().collect();

// Remove the core airs.
Expand Down Expand Up @@ -550,7 +553,7 @@ impl<F: PrimeField32> RiscvAir<F> {
})
.count();

(chip.into_inner(), local_mem_events_per_row)
(chip.into_inner().id(), local_mem_events_per_row)
})
.collect()
}
Expand Down
8 changes: 4 additions & 4 deletions crates/core/machine/src/shape/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,15 @@ impl<F: PrimeField32> CoreShapeConfig<F> {
// program.
record.shape.clone_from(&record.program.preprocessed_shape);

let shape = self.find_shape(record)?;
let shape = self.find_shape(&*record)?;
record.shape.as_mut().unwrap().extend(shape);
Ok(())
}

/// TODO move this into the executor crate
pub fn find_shape<R: Shapeable<F>>(
&self,
record: &R,
record: R,
) -> Result<Shape<RiscvAirId>, CoreShapeError> {
match record.kind() {
// If this is a packed "core" record where the cpu events are alongisde the memory init and
Expand Down Expand Up @@ -425,7 +425,7 @@ impl<F: PrimeField32> CoreShapeConfig<F> {
self.maximal_core_shapes(max_log_shard_size).into_iter().chain(precompile_shapes).collect()
}

fn estimate_lde_size(&self, shape: &Shape<RiscvAirId>) -> usize {
pub fn estimate_lde_size(&self, shape: &Shape<RiscvAirId>) -> usize {
shape.iter().map(|(air, height)| self.costs[air] * (1 << height)).sum()
}

Expand Down Expand Up @@ -510,7 +510,7 @@ impl<F: PrimeField32> Default for CoreShapeConfig<F> {
RiscvAir::<F>::precompile_airs_with_memory_events_per_row()
{
precompile_allowed_log2_heights
.insert(air.id(), (memory_events_per_row, precompile_heights.clone()));
.insert(air, (memory_events_per_row, precompile_heights.clone()));
}

Self {
Expand Down
43 changes: 43 additions & 0 deletions crates/core/machine/src/shape/shapeable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use sp1_stark::MachineRecord;

use crate::memory::NUM_LOCAL_MEMORY_ENTRIES_PER_ROW;

#[derive(Debug, Clone, Copy)]
pub enum ShardKind {
PackedCore,
Core,
Expand All @@ -26,6 +27,48 @@ pub trait Shapeable<F: PrimeField32> {
fn precompile_heights(&self) -> impl Iterator<Item = (RiscvAirId, (usize, usize, usize))>;
}

macro_rules! impl_for_ref {
($ty:ty) => {
impl<F: PrimeField32, T> Shapeable<F> for $ty
where
T: Shapeable<F>,
{
fn kind(&self) -> ShardKind {
<Self as std::ops::Deref>::deref(self).kind()
}

fn shard(&self) -> u32 {
<Self as std::ops::Deref>::deref(self).shard()
}

fn log2_shard_size(&self) -> usize {
<Self as std::ops::Deref>::deref(self).log2_shard_size()
}

fn debug_stats(&self) -> HashMap<String, usize> {
<Self as std::ops::Deref>::deref(self).debug_stats()
}

fn core_heights(&self) -> Vec<(RiscvAirId, usize)> {
<Self as std::ops::Deref>::deref(self).core_heights()
}

fn memory_heights(&self) -> Vec<(RiscvAirId, usize)> {
<Self as std::ops::Deref>::deref(self).memory_heights()
}

fn precompile_heights(
&self,
) -> impl Iterator<Item = (RiscvAirId, (usize, usize, usize))> {
<Self as std::ops::Deref>::deref(self).precompile_heights()
}
}
};
}

impl_for_ref!(&T);
impl_for_ref!(&mut T);

impl<F: PrimeField32> Shapeable<F> for ExecutionRecord {
fn kind(&self) -> ShardKind {
let contains_global_memory = !self.global_memory_initialize_events.is_empty()
Expand Down
6 changes: 5 additions & 1 deletion crates/prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@ thiserror = "1.0.63"
rayon = "1.10.0"
lru = "0.12.4"
eyre = "0.6.12"
hashbrown = { workspace = true, features = ["inline-more"], optional = true }
enum-map = { version = "2.7.3", optional = true }

[build-dependencies]
downloader = { version = "0.2", default-features = false, features = [
"rustls-tls",
"verify",
]}
] }
sha2 = { version = "0.10" }
hex = "0.4"

Expand Down Expand Up @@ -92,5 +94,7 @@ name = "post_trusted_setup"
path = "scripts/post_trusted_setup.rs"

[features]
default = ["gas"] # REMOVE ME BEFORE MERGING
native-gnark = ["sp1-recursion-gnark-ffi/native"]
debug = ["sp1-core-machine/debug"]
gas = ["dep:hashbrown", "dep:enum-map", "sp1-core-executor/gas"]
Loading

0 comments on commit 53ff6a2

Please sign in to comment.