Skip to content

Commit

Permalink
Made sierra generation external stack based. (#6303)
Browse files Browse the repository at this point in the history
  • Loading branch information
orizi authored Aug 29, 2024
1 parent 9e57d80 commit be5bc2a
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 41 deletions.
91 changes: 60 additions & 31 deletions crates/cairo-lang-sierra-generator/src/block_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
#[path = "block_generator_test.rs"]
mod test;

use cairo_lang_defs::diagnostic_utils::StableLocation;
use cairo_lang_diagnostics::Maybe;
use cairo_lang_lowering::BlockId;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use itertools::{chain, enumerate, zip_eq};
use lowering::borrow_check::analysis::StatementLocation;
Expand Down Expand Up @@ -89,13 +91,46 @@ fn add_drop_statements(
Ok(())
}

/// Elements to be processed for block code generation.
enum BlockGenStackElement {
/// Generated code for the given block.
Block(BlockId),
/// Output the given sierra statement.
Statement(pre_sierra::Statement),
/// Configuration for the following blocks.
Config { starting_cairo_location: Vec<StableLocation>, ap_tracking_state: bool },
}

/// Generates Sierra statements for a function from the given [ExprGeneratorContext].
///
/// Returns a vector of Sierra statements.
pub fn generate_function_statements(
mut context: ExprGeneratorContext<'_>,
) -> Maybe<Vec<StatementWithLocation>> {
let mut block_gen_stack = vec![BlockGenStackElement::Block(BlockId::root())];
while let Some(element) = block_gen_stack.pop() {
match element {
BlockGenStackElement::Block(block_id) => {
generate_block_code(&mut context, &mut block_gen_stack, block_id)?
}
BlockGenStackElement::Statement(statement) => context.push_statement(statement),
BlockGenStackElement::Config { starting_cairo_location, ap_tracking_state } => {
context.curr_cairo_location = starting_cairo_location;
context.set_ap_tracking(ap_tracking_state);
}
}
}
Ok(context.statements())
}

/// Generates Sierra for a given [lowering::FlatBlock].
///
/// Returns a list of Sierra statements.
/// Assumes `block_id` exists in `self.lowered.blocks`.
pub fn generate_block_code(
fn generate_block_code(
context: &mut ExprGeneratorContext<'_>,
block_id: lowering::BlockId,
block_gen_stack: &mut Vec<BlockGenStackElement>,
block_id: BlockId,
) -> Maybe<()> {
let block = context.get_lowered_block(block_id);
let statement_location: StatementLocation = (block_id, block.statements.len());
Expand Down Expand Up @@ -124,8 +159,7 @@ pub fn generate_block_code(
id: *context.block_label(*target_block_id),
});
context.push_statement(label);

generate_block_code(context, *target_block_id)?;
block_gen_stack.push(BlockGenStackElement::Block(*target_block_id));
} else {
let jump = jump_statement(
jump_libfunc_id(context.get_db()),
Expand All @@ -152,13 +186,13 @@ pub fn generate_block_code(

match info {
lowering::MatchInfo::Extern(s) => {
generate_match_extern_code(context, s, &statement_location)
generate_match_extern_code(context, block_gen_stack, s, &statement_location)
}
lowering::MatchInfo::Enum(s) => {
generate_match_enum_code(context, s, &statement_location)
generate_match_enum_code(context, block_gen_stack, s, &statement_location)
}
lowering::MatchInfo::Value(s) => {
generate_match_value_code(context, s, &statement_location)
generate_match_value_code(context, block_gen_stack, s, &statement_location)
}
}?;
}
Expand Down Expand Up @@ -402,6 +436,7 @@ fn maybe_add_dup_statement(
/// Generates Sierra code for [lowering::MatchExternInfo].
fn generate_match_extern_code(
context: &mut ExprGeneratorContext<'_>,
block_gen_stack: &mut Vec<BlockGenStackElement>,
match_info: &lowering::MatchExternInfo,
statement_location: &StatementLocation,
) -> Maybe<()> {
Expand All @@ -411,13 +446,13 @@ fn generate_match_extern_code(
let (_function_long_id, libfunc_id) =
get_concrete_libfunc_id(context.get_db(), match_info.function, false);

generate_match_code(context, libfunc_id, args, &match_info.arms)?;
Ok(())
generate_match_code(context, block_gen_stack, libfunc_id, args, &match_info.arms)
}
/// Generates Sierra code for the match a [lowering::MatchExternInfo] or [lowering::MatchEnumInfo]
/// statement.
fn generate_match_code(
context: &mut ExprGeneratorContext<'_>,
block_gen_stack: &mut Vec<BlockGenStackElement>,
libfunc_id: ConcreteLibfuncId,
args: Vec<sierra::ids::VarId>,
arms: &[lowering::MatchArm],
Expand Down Expand Up @@ -456,29 +491,23 @@ fn generate_match_code(
let ap_tracking_enabled = context.get_ap_tracking();

let match_block_location = std::mem::take(&mut context.curr_cairo_location);
block_gen_stack.push(BlockGenStackElement::Statement(end_label));
// Generate the blocks.
for (i, MatchArm { arm_selector: _, block_id, var_ids: _ }) in enumerate(arms) {
context.curr_cairo_location.clone_from(&match_block_location);
// Reset ap_tracking to the state before the match.
context.set_ap_tracking(ap_tracking_enabled);

// Add a label for each of the arm blocks, except for the first.
if i > 0 {
context.push_statement(arm_labels[i - 1].0.clone());
}
// Add branch_align to equalize gas costs across the merging paths.
context.push_statement(simple_basic_statement(
for (i, MatchArm { arm_selector: _, block_id, var_ids: _ }) in enumerate(arms).rev() {
block_gen_stack.push(BlockGenStackElement::Block(*block_id));
block_gen_stack.push(BlockGenStackElement::Statement(simple_basic_statement(
branch_align_libfunc_id(context.get_db()),
&[],
&[],
));

generate_block_code(context, *block_id)?;
)));
if i > 0 {
block_gen_stack.push(BlockGenStackElement::Statement(arm_labels[i - 1].0.clone()));
}
block_gen_stack.push(BlockGenStackElement::Config {
starting_cairo_location: match_block_location.clone(),
ap_tracking_state: ap_tracking_enabled,
});
}

// Post match.
context.push_statement(end_label);

Ok(())
}

Expand Down Expand Up @@ -545,6 +574,7 @@ fn generate_statement_struct_destructure_code(
/// enum_from_bounded_int libfunc.
fn generate_match_value_code(
context: &mut ExprGeneratorContext<'_>,
block_gen_stack: &mut Vec<BlockGenStackElement>,
match_info: &lowering::MatchEnumValue,
statement_location: &StatementLocation,
) -> Maybe<()> {
Expand All @@ -563,13 +593,13 @@ fn generate_match_value_code(
let libfunc_id = match_enum_libfunc_id(context.get_db(), concrete_enum_type)?;

let args = vec![enum_var];
generate_match_code(context, libfunc_id, args, &match_info.arms)?;
Ok(())
generate_match_code(context, block_gen_stack, libfunc_id, args, &match_info.arms)
}

/// Generates Sierra code for [lowering::MatchEnumInfo].
fn generate_match_enum_code(
context: &mut ExprGeneratorContext<'_>,
block_gen_stack: &mut Vec<BlockGenStackElement>,
match_info: &lowering::MatchEnumInfo,
statement_location: &StatementLocation,
) -> Maybe<()> {
Expand All @@ -580,8 +610,7 @@ fn generate_match_enum_code(
let libfunc_id = match_enum_libfunc_id(context.get_db(), concrete_enum_type)?;

let args = vec![matched_enum];
generate_match_code(context, libfunc_id, args, &match_info.arms)?;
Ok(())
generate_match_code(context, block_gen_stack, libfunc_id, args, &match_info.arms)
}

/// Generates Sierra code for [lowering::StatementSnapshot].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use cairo_lang_filesystem::flag::Flag;
use cairo_lang_filesystem::ids::FlagId;
use cairo_lang_lowering as lowering;
use cairo_lang_lowering::db::LoweringGroup;
use cairo_lang_lowering::BlockId;
use cairo_lang_semantic::test_utils::setup_test_function;
use cairo_lang_test_utils::parse_test_file::TestRunnerResult;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
Expand All @@ -15,7 +14,7 @@ use cairo_lang_utils::UpcastMut;
use lowering::fmt::LoweredFormatter;
use lowering::ids::ConcreteFunctionWithBodyId;

use super::generate_block_code;
use super::generate_function_statements;
use crate::expr_generator_context::ExprGeneratorContext;
use crate::lifetime::find_variable_lifetime;
use crate::replace_ids::replace_sierra_ids;
Expand Down Expand Up @@ -77,7 +76,7 @@ fn block_generator_test(
// Generate (pre-)Sierra statements.
let lifetime = find_variable_lifetime(&lowered, &OrderedHashSet::default())
.expect("Failed to retrieve lifetime information.");
let mut expr_generator_context = ExprGeneratorContext::new(
let expr_generator_context = ExprGeneratorContext::new(
db,
&lowered,
function_id,
Expand All @@ -87,8 +86,7 @@ fn block_generator_test(

let mut expected_sierra_code = String::default();

generate_block_code(&mut expr_generator_context, BlockId::root()).unwrap();
for statement in expr_generator_context.statements() {
for statement in generate_function_statements(expr_generator_context).unwrap() {
expected_sierra_code.push_str(&replace_sierra_ids(db, &statement).statement.to_string(db));
expected_sierra_code.push('\n');
}
Expand Down
7 changes: 2 additions & 5 deletions crates/cairo-lang-sierra-generator/src/function_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ use cairo_lang_sierra::ids::ConcreteLibfuncId;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
use cairo_lang_utils::Intern;
use lowering::BlockId;

use crate::block_generator::generate_block_code;
use crate::block_generator::generate_function_statements;
use crate::db::SierraGenGroup;
use crate::expr_generator_context::ExprGeneratorContext;
use crate::lifetime::{find_variable_lifetime, SierraGenVar};
Expand Down Expand Up @@ -114,9 +113,7 @@ fn get_function_code(
}

// Generate the function's code.
generate_block_code(&mut context, BlockId::root())?;
let db = context.get_db();
let statements = context.statements();
let statements = generate_function_statements(context)?;

let statements = add_store_statements(
db,
Expand Down

0 comments on commit be5bc2a

Please sign in to comment.