Skip to content

Commit

Permalink
Added mul const folding.
Browse files Browse the repository at this point in the history
commit-id:4b2734ae
  • Loading branch information
orizi committed Jul 29, 2024
1 parent 47287ad commit d35f25a
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 0 deletions.
16 changes: 16 additions & 0 deletions crates/cairo-lang-lowering/src/optimizations/const_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,15 @@ impl<'a> ConstFoldingContext<'a> {
if val.is_zero() {
self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
}
} else if self.wide_mul_fns.contains(&stmt.function) {
let lhs = self.as_int(stmt.inputs[0].var_id)?;
let rhs = self.as_int(stmt.inputs[1].var_id)?;
let value = lhs * rhs;
let output = stmt.outputs[0];
let ty = self.variables[output].ty;
let value = ConstValue::Int(value, ty);
self.var_info.insert(output, VarInfo::Const(value.clone()));
return Some(Statement::Const(StatementConst { value, output }));
} else if stmt.function == self.storage_base_address_from_felt252 {
let input_var = stmt.inputs[0].var_id;
if let Some(ConstValue::Int(val, ty)) = self.as_const(input_var) {
Expand Down Expand Up @@ -379,6 +388,8 @@ struct LibfuncInfo<'a> {
uadd_fns: UnorderedHashSet<FunctionId>,
/// The set of functions to subtract unsigned ints.
usub_fns: UnorderedHashSet<FunctionId>,
/// The set of functions to multiply integers.
wide_mul_fns: UnorderedHashSet<FunctionId>,
/// The storage access module.
storage_access_module: ModuleHelper<'a>,
/// Type ranges.
Expand Down Expand Up @@ -409,6 +420,10 @@ impl<'a> LibfuncInfo<'a> {
let usub_fns = UnorderedHashSet::<_>::from_iter(
utypes.map(|ty| integer_module.function_id(format!("{ty}_overflowing_sub"), vec![])),
);
let wide_mul_fns = UnorderedHashSet::<_>::from_iter(
["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]
.map(|ty| integer_module.function_id(format!("{ty}_wide_mul"), vec![])),
);
let type_value_ranges = UnorderedHashMap::from_iter(
[
("u8", TypeRange::closed(0, u8::MAX)),
Expand Down Expand Up @@ -436,6 +451,7 @@ impl<'a> LibfuncInfo<'a> {
nz_fns,
uadd_fns,
usub_fns,
wide_mul_fns,
storage_access_module,
type_value_ranges,
}
Expand Down
121 changes: 121 additions & 0 deletions crates/cairo-lang-lowering/src/optimizations/test_data/const_folding
Original file line number Diff line number Diff line change
Expand Up @@ -1763,3 +1763,124 @@ End:
Return(v18)

//! > lowering_diagnostics

//! > ==========================================================================

//! > Mul const fold.

//! > test_runner_name
test_match_optimizer

//! > function
fn foo() -> u8 {
5 * 3
}

//! > function_name
foo

//! > module_code

//! > semantic_diagnostics

//! > before
Parameters:
blk0 (root):
Statements:
(v0: core::integer::u8) <- 5
(v1: core::integer::u8) <- 3
(v2: core::integer::u16) <- core::integer::u8_wide_mul(v0, v1)
End:
Match(match core::integer::downcast::<core::integer::u16, core::integer::u8>(v2) {
Option::Some(v3) => blk1,
Option::None => blk2,
})

blk1:
Statements:
(v4: (core::integer::u8,)) <- struct_construct(v3)
(v5: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Ok(v4)
End:
Goto(blk3, {v5 -> v6})

blk2:
Statements:
(v7: core::array::Array::<core::felt252>) <- core::array::array_new::<core::felt252>()
(v8: core::felt252) <- 608642107937639184217240406363762551
(v9: core::array::Array::<core::felt252>) <- core::array::array_append::<core::felt252>(v7, v8)
(v10: core::panics::Panic) <- struct_construct()
(v11: (core::panics::Panic, core::array::Array::<core::felt252>)) <- struct_construct(v10, v9)
(v12: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Err(v11)
End:
Goto(blk3, {v12 -> v6})

blk3:
Statements:
End:
Match(match_enum(v6) {
PanicResult::Ok(v13) => blk4,
PanicResult::Err(v14) => blk5,
})

blk4:
Statements:
(v17: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Ok(v13)
End:
Return(v17)

blk5:
Statements:
(v18: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Err(v14)
End:
Return(v18)

//! > after
Parameters:
blk0 (root):
Statements:
(v0: core::integer::u8) <- 5
(v1: core::integer::u8) <- 3
(v2: core::integer::u16) <- 15
(v3: core::integer::u8) <- 15
End:
Goto(blk1, {})

blk1:
Statements:
(v4: (core::integer::u8,)) <- struct_construct(v3)
(v5: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Ok(v4)
End:
Goto(blk3, {v5 -> v6})

blk2:
Statements:
(v7: core::array::Array::<core::felt252>) <- core::array::array_new::<core::felt252>()
(v8: core::felt252) <- 608642107937639184217240406363762551
(v9: core::array::Array::<core::felt252>) <- core::array::array_append::<core::felt252>(v7, v8)
(v10: core::panics::Panic) <- struct_construct()
(v11: (core::panics::Panic, core::array::Array::<core::felt252>)) <- struct_construct(v10, v9)
(v12: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Err(v11)
End:
Goto(blk3, {v12 -> v6})

blk3:
Statements:
End:
Match(match_enum(v6) {
PanicResult::Ok(v13) => blk4,
PanicResult::Err(v14) => blk5,
})

blk4:
Statements:
(v17: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Ok(v13)
End:
Return(v17)

blk5:
Statements:
(v18: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Err(v14)
End:
Return(v18)

//! > lowering_diagnostics

0 comments on commit d35f25a

Please sign in to comment.