diff --git a/crates/cairo-lang-lowering/src/optimizations/const_folding.rs b/crates/cairo-lang-lowering/src/optimizations/const_folding.rs index 912e62fe781..af774a21c9e 100644 --- a/crates/cairo-lang-lowering/src/optimizations/const_folding.rs +++ b/crates/cairo-lang-lowering/src/optimizations/const_folding.rs @@ -188,6 +188,15 @@ impl<'a> ConstFoldingContext<'a> { if val.is_zero() { self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0])); } + } else if self.wide_mul_fns.contains(&stmt.function) { + let lhs = self.as_int(stmt.inputs[0].var_id)?; + let rhs = self.as_int(stmt.inputs[1].var_id)?; + let value = lhs * rhs; + let output = stmt.outputs[0]; + let ty = self.variables[output].ty; + let value = ConstValue::Int(value, ty); + self.var_info.insert(output, VarInfo::Const(value.clone())); + return Some(Statement::Const(StatementConst { value, output })); } else if stmt.function == self.storage_base_address_from_felt252 { let input_var = stmt.inputs[0].var_id; if let Some(ConstValue::Int(val, ty)) = self.as_const(input_var) { @@ -379,6 +388,8 @@ struct LibfuncInfo<'a> { uadd_fns: UnorderedHashSet, /// The set of functions to subtract unsigned ints. usub_fns: UnorderedHashSet, + /// The set of functions to multiply integers. + wide_mul_fns: UnorderedHashSet, /// The storage access module. storage_access_module: ModuleHelper<'a>, /// Type ranges. @@ -409,6 +420,10 @@ impl<'a> LibfuncInfo<'a> { let usub_fns = UnorderedHashSet::<_>::from_iter( utypes.map(|ty| integer_module.function_id(format!("{ty}_overflowing_sub"), vec![])), ); + let wide_mul_fns = UnorderedHashSet::<_>::from_iter( + ["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"] + .map(|ty| integer_module.function_id(format!("{ty}_wide_mul"), vec![])), + ); let type_value_ranges = UnorderedHashMap::from_iter( [ ("u8", TypeRange::closed(0, u8::MAX)), @@ -436,6 +451,7 @@ impl<'a> LibfuncInfo<'a> { nz_fns, uadd_fns, usub_fns, + wide_mul_fns, storage_access_module, type_value_ranges, } diff --git a/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding b/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding index 5d7bde9b0b6..3bcaa17167e 100644 --- a/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding +++ b/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding @@ -1763,3 +1763,124 @@ End: Return(v18) //! > lowering_diagnostics + +//! > ========================================================================== + +//! > Mul const fold. + +//! > test_runner_name +test_match_optimizer + +//! > function +fn foo() -> u8 { + 5 * 3 +} + +//! > function_name +foo + +//! > module_code + +//! > semantic_diagnostics + +//! > before +Parameters: +blk0 (root): +Statements: + (v0: core::integer::u8) <- 5 + (v1: core::integer::u8) <- 3 + (v2: core::integer::u16) <- core::integer::u8_wide_mul(v0, v1) +End: + Match(match core::integer::downcast::(v2) { + Option::Some(v3) => blk1, + Option::None => blk2, + }) + +blk1: +Statements: + (v4: (core::integer::u8,)) <- struct_construct(v3) + (v5: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Ok(v4) +End: + Goto(blk3, {v5 -> v6}) + +blk2: +Statements: + (v7: core::array::Array::) <- core::array::array_new::() + (v8: core::felt252) <- 608642107937639184217240406363762551 + (v9: core::array::Array::) <- core::array::array_append::(v7, v8) + (v10: core::panics::Panic) <- struct_construct() + (v11: (core::panics::Panic, core::array::Array::)) <- struct_construct(v10, v9) + (v12: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Err(v11) +End: + Goto(blk3, {v12 -> v6}) + +blk3: +Statements: +End: + Match(match_enum(v6) { + PanicResult::Ok(v13) => blk4, + PanicResult::Err(v14) => blk5, + }) + +blk4: +Statements: + (v17: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Ok(v13) +End: + Return(v17) + +blk5: +Statements: + (v18: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Err(v14) +End: + Return(v18) + +//! > after +Parameters: +blk0 (root): +Statements: + (v0: core::integer::u8) <- 5 + (v1: core::integer::u8) <- 3 + (v2: core::integer::u16) <- 15 + (v3: core::integer::u8) <- 15 +End: + Goto(blk1, {}) + +blk1: +Statements: + (v4: (core::integer::u8,)) <- struct_construct(v3) + (v5: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Ok(v4) +End: + Goto(blk3, {v5 -> v6}) + +blk2: +Statements: + (v7: core::array::Array::) <- core::array::array_new::() + (v8: core::felt252) <- 608642107937639184217240406363762551 + (v9: core::array::Array::) <- core::array::array_append::(v7, v8) + (v10: core::panics::Panic) <- struct_construct() + (v11: (core::panics::Panic, core::array::Array::)) <- struct_construct(v10, v9) + (v12: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Err(v11) +End: + Goto(blk3, {v12 -> v6}) + +blk3: +Statements: +End: + Match(match_enum(v6) { + PanicResult::Ok(v13) => blk4, + PanicResult::Err(v14) => blk5, + }) + +blk4: +Statements: + (v17: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Ok(v13) +End: + Return(v17) + +blk5: +Statements: + (v18: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Err(v14) +End: + Return(v18) + +//! > lowering_diagnostics