Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split compile function and use interpreter for large code. #2569

Merged
merged 10 commits into from
Mar 24, 2025
108 changes: 82 additions & 26 deletions executor/src/witgen/jit/function_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,18 @@ use crate::witgen::{
use super::{
block_machine_processor::BlockMachineProcessor,
compiler::{compile_effects, CompiledFunction},
effect::Effect,
interpreter::EffectsInterpreter,
prover_function_heuristics::ProverFunction,
variable::Variable,
witgen_inference::CanProcessCall,
};

/// Inferred witness generation routines that are larger than
/// this number of "statements" will use the interpreter instead of the compiler
/// due to the large compilation ressources required.
const MAX_COMPILED_CODE_SIZE: usize = 500;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be good to leave a small comment here for usage/context


#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct CacheKey<T: FieldElement> {
bus_id: T,
Expand Down Expand Up @@ -136,28 +143,26 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
) -> &Option<CacheEntry<T>> {
if !self.witgen_functions.contains_key(cache_key) {
record_start("Auto-witgen code derivation");
let f = match T::known_field() {
// TODO: Currently, code generation only supports the Goldilocks
// fields. We can't enable the interpreter for non-goldilocks
// fields due to a limitation of autowitgen.
Some(KnownField::GoldilocksField) => {
self.compile_witgen_function(can_process, cache_key, false)
}
_ => None,
};
assert!(self.witgen_functions.insert(cache_key.clone(), f).is_none());
let compiled = self
.derive_witgen_function(can_process, cache_key)
.and_then(|(result, prover_functions)| {
self.compile_witgen_function(result, prover_functions, cache_key)
});
assert!(self
.witgen_functions
.insert(cache_key.clone(), compiled)
.is_none());

record_end("Auto-witgen code derivation");
}
self.witgen_functions.get(cache_key).unwrap()
}

fn compile_witgen_function(
fn derive_witgen_function(
&self,
can_process: impl CanProcessCall<T>,
cache_key: &CacheKey<T>,
interpreted: bool,
) -> Option<CacheEntry<T>> {
) -> Option<(ProcessorResult<T>, Vec<ProverFunction<'a, T>>)> {
log::debug!(
"Compiling JIT function for\n Machine: {}\n Connection: {}\n Inputs: {:?}{}",
self.machine_name,
Expand All @@ -169,13 +174,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
.unwrap_or_default()
);

let (
ProcessorResult {
code,
range_constraints,
},
prover_functions,
) = self
let (processor_result, prover_functions) = self
.processor
.generate_code(
can_process,
Expand All @@ -193,7 +192,8 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
.ok()?;

log::debug!("=> Success!");
let out_of_bounds_vars = code
let out_of_bounds_vars = processor_result
.code
.iter()
.flat_map(|effect| effect.referenced_variables())
.filter_map(|var| match var {
Expand All @@ -203,7 +203,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
.filter(|cell| cell.row_offset < -1 || cell.row_offset >= self.block_size as i32)
.collect_vec();
if !out_of_bounds_vars.is_empty() {
log::debug!("Code:\n{}", format_code(&code));
log::debug!("Code:\n{}", format_code(&processor_result.code));
panic!(
"Expected JITed code to only reference cells in the block + the last row \
of the previous block, i.e. rows -1 until (including) {}, but it does reference the following:\n{}",
Expand All @@ -212,25 +212,51 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
);
}

log::trace!("Generated code ({} steps)", code.len());
log::trace!("Generated code ({} steps)", processor_result.code.len());
Some((processor_result, prover_functions))
}

fn compile_witgen_function(
&self,
result: ProcessorResult<T>,
prover_functions: Vec<ProverFunction<'a, T>>,
cache_key: &CacheKey<T>,
) -> Option<CacheEntry<T>> {
let known_inputs = cache_key
.known_args
.iter()
.enumerate()
.filter_map(|(i, b)| if b { Some(Variable::Param(i)) } else { None })
.collect::<Vec<_>>();

let has_prover_function_call = has_prover_function_call(&result.code);

// TODO This is the goal, but we need to implement prover unctions for the interpreter first.

// Use the compiler for goldilocks with at most MAX_COMPILED_CODE_SIZE statements and
// the interpreter otherwise.
#[allow(unused)]
let interpreted = !matches!(T::known_field(), Some(KnownField::GoldilocksField))
|| code_size(&result.code) > MAX_COMPILED_CODE_SIZE;

let interpreted = !matches!(T::known_field(), Some(KnownField::GoldilocksField));

if interpreted && has_prover_function_call {
log::debug!("Interpreter does not yet implement prover functions.");
return None;
}

let function = if interpreted {
log::trace!("Building effects interpreter...");
WitgenFunction::Interpreted(EffectsInterpreter::try_new(&known_inputs, &code)?)
WitgenFunction::Interpreted(EffectsInterpreter::new(&known_inputs, &result.code))
} else {
log::trace!("Compiling effects...");
WitgenFunction::Compiled(
compile_effects(
self.fixed_data.analyzed,
self.column_layout.clone(),
&known_inputs,
&code,
&result.code,
prover_functions,
)
.unwrap(),
Expand All @@ -240,7 +266,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {

Some(CacheEntry {
function,
range_constraints,
range_constraints: result.range_constraints,
})
}

Expand Down Expand Up @@ -281,3 +307,33 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
Ok(true)
}
}

/// Returns the elements in the code and thus a rough estimate of the number of steps
fn code_size<T: FieldElement>(code: &[Effect<T, Variable>]) -> usize {
code.iter()
.map(|effect| match effect {
Effect::Assignment(..)
| Effect::Assertion(..)
| Effect::MachineCall(..)
| Effect::ProverFunctionCall(..) => 1,
Effect::RangeConstraint(..) => unreachable!(),
Effect::Branch(_, first, second) => code_size(first) + code_size(second) + 1,
})
.sum()
}

/// Returns true if there is any prover function call in the code.
fn has_prover_function_call<'a, T: FieldElement>(
code: impl IntoIterator<Item = &'a Effect<T, Variable>> + 'a,
) -> bool {
code.into_iter().any(|effect| match effect {
Effect::ProverFunctionCall(..) => true,
Effect::Branch(_, if_branch, else_branch) => {
has_prover_function_call(if_branch) || has_prover_function_call(else_branch)
}
Effect::Assignment(..)
| Effect::RangeConstraint(..)
| Effect::Assertion(..)
| Effect::MachineCall(..) => false,
})
}
24 changes: 3 additions & 21 deletions executor/src/witgen/jit/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,25 +89,7 @@ enum MachineCallArgumentIdx {
}

impl<T: FieldElement> EffectsInterpreter<T> {
pub fn try_new(known_inputs: &[Variable], effects: &[Effect<T, Variable>]) -> Option<Self> {
// TODO: interpreter doesn't support prover functions yet
fn has_prover_fn<T: FieldElement>(effect: &Effect<T, Variable>) -> bool {
match effect {
Effect::ProverFunctionCall(..) => true,
Effect::Branch(_, if_branch, else_branch) => {
if if_branch.iter().any(has_prover_fn) || else_branch.iter().any(has_prover_fn)
{
return true;
}
false
}
_ => false,
}
}
if effects.iter().any(has_prover_fn) {
return None;
}

pub fn new(known_inputs: &[Variable], effects: &[Effect<T, Variable>]) -> Self {
let mut actions = vec![];
let mut var_mapper = VariableMapper::new();

Expand All @@ -121,7 +103,7 @@ impl<T: FieldElement> EffectsInterpreter<T> {
actions,
};
assert!(actions_are_valid(&ret.actions, BTreeSet::new()));
Some(ret)
ret
}

/// Returns an iterator of actions to load all accessed fixed column values into variables.
Expand Down Expand Up @@ -680,7 +662,7 @@ mod test {
.unwrap();

// generate and call the interpreter
let interpreter = EffectsInterpreter::try_new(&known_inputs, &result.code).unwrap();
let interpreter = EffectsInterpreter::new(&known_inputs, &result.code);

Self {
analyzed,
Expand Down
Loading