Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] JIT: Inline witness assignments #2559

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 67 additions & 43 deletions executor/src/witgen/jit/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ fn witgen_code<T: FieldElement>(
let load_known_inputs = known_inputs
.iter()
.map(|v| {
let var_name = variable_to_string(v);
let value = match v {
Variable::WitnessCell(c) => {
format!("get(data, row_offset, {}, {})", c.row_offset, c.id)
Expand All @@ -237,7 +236,8 @@ fn witgen_code<T: FieldElement>(
unreachable!("Machine call variables should not be pre-known.")
}
};
format!(" let {var_name} = {value};")

indent(set(v, &value, true, false), 1)
})
.format("\n");

Expand All @@ -252,12 +252,11 @@ fn witgen_code<T: FieldElement>(
})
.unique()
.map(|(var, cell)| {
format!(
" let {} = get_fixed_value(fixed_data, {}, (row_offset + {}));",
variable_to_string(var),
cell.id,
cell.row_offset,
)
let value = format!(
"get_fixed_value(fixed_data, {}, (row_offset + {}))",
cell.id, cell.row_offset
);
indent(set(var, &value, true, false), 1)
})
.format("\n");

Expand All @@ -270,12 +269,9 @@ fn witgen_code<T: FieldElement>(
let store_values = vars_known
.iter()
.filter_map(|var| {
let value = variable_to_string(var);
let value = get(var);
match var {
Variable::WitnessCell(cell) => Some(format!(
" set(data, row_offset, {}, {}, {value});",
cell.row_offset, cell.id,
)),
Variable::WitnessCell(_) => None,
Variable::Param(i) => Some(format!(" set_param(params, {i}, {value});")),
Variable::FixedCell(_) => panic!("Fixed columns should not be written to."),
Variable::IntermediateCell(_) => {
Expand Down Expand Up @@ -363,14 +359,7 @@ fn format_effects_inner<T: FieldElement>(

fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>, is_top_level: bool) -> String {
match effect {
Effect::Assignment(var, e) => {
format!(
"{}{} = {};",
if is_top_level { "let " } else { "" },
variable_to_string(var),
format_expression(e)
)
}
Effect::Assignment(var, e) => set(var, &format_expression(e), is_top_level, false),
Effect::RangeConstraint(..) => {
unreachable!("Final code should not contain pure range constraints.")
}
Expand All @@ -390,22 +379,30 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>, is_top_level: bo
.iter()
.zip_eq(known)
.map(|(v, known)| {
let var_name = variable_to_string(v);
if known {
format!("LookupCell::Input(&{var_name})")
let input = get(v);
format!("LookupCell::Input(&{input})")
} else {
if is_top_level {
result_vars.push(var_name.clone());
result_vars.push(v);
}
// Assumes that get returns a simple variable name.
let var_name = get(v);
format!("LookupCell::Output(&mut {var_name})")
}
})
.format(", ")
.to_string();
let var_decls = result_vars
.iter()
.map(|var_name| format!("let mut {var_name} = FieldElement::default();\n"))
.format("");
let var_decls = if result_vars.is_empty() {
"".to_string()
} else {
result_vars
.iter()
.map(|var_name| set(var_name, "FieldElement::default()", is_top_level, true))
.format("\n")
.to_string()
+ "\n"
};
format!(
"{var_decls}assert!(call_machine(mutable_state, {id}.into(), MutSlice::from((&mut [{args}]).as_mut_slice())));"
)
Expand All @@ -416,12 +413,17 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>, is_top_level: bo
row_offset,
inputs,
}) => {
format!(
"{}[{}] = prover_function_{function_index}(mutable_state, input_from_channel, output_to_channel, row_offset + {row_offset}, &[{}]);",
if is_top_level { "let " } else { "" },
targets.iter().map(variable_to_string).format(", "),
inputs.iter().map(variable_to_string).format(", ")
)
let function_call = format!(
"let result = prover_function_{function_index}(mutable_state, input_from_channel, output_to_channel, row_offset + {row_offset}, &[{}]);",
inputs.iter().map(get).format(", ")
);
let store_results = targets
.iter()
.enumerate()
.map(|(i, v)| set(v, &format!("result[{i}]"), is_top_level, false))
.format("\n");
let block = format!("{function_call}\n{store_results}");
format!("{{\n{}\n}}", indent(block, 1))
}
Effect::Branch(condition, first, second) => {
let var_decls = if is_top_level {
Expand All @@ -434,14 +436,9 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>, is_top_level: bo
.sorted()
.dedup()
.map(|(v, needs_mut)| {
let v = variable_to_string(v);
if needs_mut {
format!("let mut {v} = FieldElement::default();\n")
} else {
format!("let {v};\n")
}
set(v, "FieldElement::default()", is_top_level, needs_mut)
})
.format("")
.format("\n")
.to_string()
} else {
"".to_string()
Expand All @@ -465,7 +462,7 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>, is_top_level: bo
fn format_expression<T: FieldElement>(e: &SymbolicExpression<T, Variable>) -> String {
match e {
SymbolicExpression::Concrete(v) => format!("FieldElement::from({v})"),
SymbolicExpression::Symbol(symbol, _) => variable_to_string(symbol),
SymbolicExpression::Symbol(symbol, _) => get(symbol),
SymbolicExpression::BinaryOperation(left, op, right, _) => {
let left = format_expression(left);
let right = format_expression(right);
Expand Down Expand Up @@ -498,7 +495,7 @@ fn format_condition<T: FieldElement>(
condition,
}: &BranchCondition<T, Variable>,
) -> String {
let var = format!("IntType::from({})", variable_to_string(variable));
let var = format!("IntType::from({})", get(variable));
let (min, max) = condition.range();
match min.cmp(&max) {
Ordering::Equal => format!("{var} == {min}",),
Expand All @@ -507,6 +504,33 @@ fn format_condition<T: FieldElement>(
}
}

fn set(v: &Variable, value: &str, is_top_level: bool, needs_mut: bool) -> String {
match v {
Variable::WitnessCell(cell) => {
format!(
"set(data, row_offset, {}, {}, {value});",
cell.row_offset, cell.id
)
}
_ => format!(
"{}{}{} = {};",
if is_top_level { "let " } else { "" },
if needs_mut { "mut " } else { "" },
variable_to_string(v),
value
),
}
}

fn get(v: &Variable) -> String {
match v {
Variable::WitnessCell(cell) => {
format!("get(data, row_offset, {}, {})", cell.row_offset, cell.id)
}
_ => variable_to_string(v),
}
}

/// Returns the name of a local (stack) variable for the given expression variable.
fn variable_to_string(v: &Variable) -> String {
match v {
Expand Down
12 changes: 6 additions & 6 deletions executor/src/witgen/jit/function_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
can_process: impl CanProcessCall<T>,
cache_key: &CacheKey<T>,
) -> Option<CacheEntry<T>> {
log::debug!(
log::info!(
"Compiling JIT function for\n Machine: {}\n Connection: {}\n Inputs: {:?}{}",
self.machine_name,
self.parts.bus_receives[&cache_key.bus_id],
Expand All @@ -152,14 +152,14 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
)
.map_err(|e| {
// These errors can be pretty verbose and are quite common currently.
log::debug!(
log::info!(
"=> Error generating JIT code: {}\n...",
e.to_string().lines().take(5).join("\n")
);
})
.ok()?;

log::debug!("=> Success!");
log::info!("=> Success!");
let out_of_bounds_vars = code
.iter()
.flat_map(|effect| effect.referenced_variables())
Expand All @@ -179,15 +179,15 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
);
}

log::trace!("Generated code ({} steps)", code.len());
log::info!("Generated code ({} steps)", code.len());
let known_inputs = cache_key
.known_args
.iter()
.enumerate()
.filter_map(|(i, b)| if b { Some(Variable::Param(i)) } else { None })
.collect::<Vec<_>>();

log::trace!("Compiling effects...");
log::info!("Compiling effects...");
let function = compile_effects(
self.fixed_data.analyzed,
self.column_layout.clone(),
Expand All @@ -196,7 +196,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
prover_functions,
)
.unwrap();
log::trace!("Compilation done.");
log::info!("Compilation done.");

Some(CacheEntry {
function,
Expand Down
18 changes: 16 additions & 2 deletions jit-compiler/src/compiler.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use mktemp::Temp;
use std::{
fs::{self},
path::PathBuf,
process::Command,
str::from_utf8,
sync::Arc,
Expand Down Expand Up @@ -83,11 +84,24 @@ fn cargo_toml(opt_level: Option<u32>) -> String {
}
}

const DEBUG: bool = false;

/// Compiles the given code and returns the path to the
/// temporary directory containing the compiled library
/// and the path to the compiled library.
pub fn call_cargo(code: &str, opt_level: Option<u32>) -> Result<PathInTempDir, String> {
let dir = mktemp::Temp::new_dir().unwrap();
let dir_tmp = mktemp::Temp::new_dir().unwrap();

let dir = if DEBUG {
let dir = PathBuf::from("../cargo_dir");
// rm -r cargo_dir/*
fs::remove_dir_all(&dir).ok();
fs::create_dir(&dir).unwrap();
dir
} else {
dir_tmp.as_path().to_path_buf()
};

fs::write(dir.join("Cargo.toml"), cargo_toml(opt_level)).unwrap();
fs::create_dir(dir.join("src")).unwrap();
fs::write(dir.join("src").join("lib.rs"), code).unwrap();
Expand Down Expand Up @@ -132,7 +146,7 @@ pub fn call_cargo(code: &str, opt_level: Option<u32>) -> Result<PathInTempDir, S
.join("release")
.join(format!("libpowdr_jit_compiled.{extension}"));
Ok(PathInTempDir {
dir,
dir: dir_tmp,
path: lib_path.to_str().unwrap().to_string(),
})
}
Expand Down
Loading