From d923eb7feac921d41fdb8baf0477a13d812200b4 Mon Sep 17 00:00:00 2001 From: Ori Ziv Date: Thu, 25 Jul 2024 12:28:02 +0300 Subject: [PATCH] Added const-folding for the downcast libfunc. commit-id:74d741f8 --- .../src/optimizations/const_folding.rs | 29 ++ .../src/optimizations/test_data/const_folding | 485 ++++++++++++++++++ .../cairo-lang-lowering/src/test_data/match | 7 +- 3 files changed, 517 insertions(+), 4 deletions(-) diff --git a/crates/cairo-lang-lowering/src/optimizations/const_folding.rs b/crates/cairo-lang-lowering/src/optimizations/const_folding.rs index 2441cdd2bea..4214fd57c81 100644 --- a/crates/cairo-lang-lowering/src/optimizations/const_folding.rs +++ b/crates/cairo-lang-lowering/src/optimizations/const_folding.rs @@ -246,6 +246,31 @@ impl<'a> ConstFoldingContext<'a> { FlatBlockEnd::Goto(arm.block_id, Default::default()), ) }); + } else if let Some(extrn) = info.function.get_extern(self.db) { + if extrn == self.downcast { + let input_var = info.inputs[0].var_id; + let Some(VarInfo::Const(ConstValue::Int(value, _))) = self.var_info.get(&input_var) + else { + return None; + }; + let success_output = info.arms[0].var_ids[0]; + let ty = self.variables[success_output].ty; + return Some( + if corelib::validate_literal(self.db.upcast(), ty, value.clone()).is_ok() { + let value = ConstValue::Int(value.clone(), ty); + self.var_info.insert(success_output, VarInfo::Const(value.clone())); + ( + Some(Statement::Const(StatementConst { + value, + output: success_output, + })), + FlatBlockEnd::Goto(info.arms[0].block_id, Default::default()), + ) + } else { + (None, FlatBlockEnd::Goto(info.arms[1].block_id, Default::default())) + }, + ); + } } None } @@ -328,6 +353,8 @@ struct LibfuncInfo<'a> { into_box: ExternFunctionId, /// The `upcast` libfunc. upcast: ExternFunctionId, + /// The `downcast` libfunc. + downcast: ExternFunctionId, /// The `storage_base_address_from_felt252` libfunc. storage_base_address_from_felt252: FunctionId, /// The set of functions that check if a number is zero. @@ -343,6 +370,7 @@ impl<'a> LibfuncInfo<'a> { let into_box = box_module.extern_function_id("into_box"); let integer_module = core.submodule("integer"); let upcast = integer_module.extern_function_id("upcast"); + let downcast = integer_module.extern_function_id("downcast"); let starknet_module = core.submodule("starknet"); let storage_access_module = starknet_module.submodule("storage_access"); let storage_base_address_from_felt252 = @@ -356,6 +384,7 @@ impl<'a> LibfuncInfo<'a> { felt_sub, into_box, upcast, + downcast, storage_base_address_from_felt252, nz_fns, storage_access_module, diff --git a/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding b/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding index cdf378288fd..982e3683a74 100644 --- a/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding +++ b/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding @@ -766,6 +766,491 @@ End: //! > ========================================================================== +//! > Downcast const success. + +//! > test_runner_name +test_match_optimizer + +//! > function +fn foo(x: u64) -> u64 { + x / 1_u128.try_into().unwrap() +} + +//! > function_name +foo + +//! > module_code + +//! > semantic_diagnostics + +//! > before +Parameters: v0: core::integer::u64 +blk0 (root): +Statements: + (v1: core::integer::u128) <- 1 +End: + Match(match core::integer::downcast::(v1) { + Option::Some(v2) => blk1, + Option::None => blk2, + }) + +blk1: +Statements: + (v3: core::option::Option::) <- Option::Some(v2) +End: + Goto(blk3, {v3 -> v4}) + +blk2: +Statements: + (v5: ()) <- struct_construct() + (v6: core::option::Option::) <- Option::None(v5) +End: + Goto(blk3, {v6 -> v4}) + +blk3: +Statements: +End: + Match(match_enum(v4) { + Option::Some(v7) => blk4, + Option::None(v8) => blk5, + }) + +blk4: +Statements: + (v9: (core::integer::u64,)) <- struct_construct(v7) + (v10: core::panics::PanicResult::<(core::integer::u64,)>) <- PanicResult::Ok(v9) +End: + Goto(blk6, {v10 -> v11}) + +blk5: +Statements: + (v12: core::array::Array::) <- core::array::array_new::() + (v13: core::felt252) <- 29721761890975875353235833581453094220424382983267374 + (v14: core::array::Array::) <- core::array::array_append::(v12, v13) + (v15: core::panics::Panic) <- struct_construct() + (v16: (core::panics::Panic, core::array::Array::)) <- struct_construct(v15, v14) + (v17: core::panics::PanicResult::<(core::integer::u64,)>) <- PanicResult::Err(v16) +End: + Goto(blk6, {v17 -> v11}) + +blk6: +Statements: +End: + Match(match_enum(v11) { + PanicResult::Ok(v18) => blk7, + PanicResult::Err(v19) => blk13, + }) + +blk7: +Statements: + (v20: core::integer::u64) <- struct_destructure(v18) +End: + Match(match core::integer::u64_is_zero(v20) { + IsZeroResult::Zero => blk8, + IsZeroResult::NonZero(v21) => blk9, + }) + +blk8: +Statements: + (v22: core::array::Array::) <- core::array::array_new::() + (v23: core::felt252) <- 5420154128225384396790819266608 + (v24: core::array::Array::) <- core::array::array_append::(v22, v23) + (v25: core::panics::Panic) <- struct_construct() + (v26: (core::panics::Panic, core::array::Array::)) <- struct_construct(v25, v24) + (v27: core::panics::PanicResult::<(core::integer::u64,)>) <- PanicResult::Err(v26) +End: + Goto(blk10, {v27 -> v28}) + +blk9: +Statements: + (v29: core::integer::u64, v30: core::integer::u64) <- core::integer::u64_safe_divmod(v0, v21) + (v31: (core::integer::u64,)) <- struct_construct(v29) + (v32: core::panics::PanicResult::<(core::integer::u64,)>) <- PanicResult::Ok(v31) +End: + Goto(blk10, {v32 -> v28}) + +blk10: +Statements: +End: + Match(match_enum(v28) { + PanicResult::Ok(v33) => blk11, + PanicResult::Err(v34) => blk12, + }) + +blk11: +Statements: + (v37: core::panics::PanicResult::<(core::integer::u64,)>) <- PanicResult::Ok(v33) +End: + Return(v37) + +blk12: +Statements: + (v38: core::panics::PanicResult::<(core::integer::u64,)>) <- PanicResult::Err(v34) +End: + Return(v38) + +blk13: +Statements: + (v39: core::panics::PanicResult::<(core::integer::u64,)>) <- PanicResult::Err(v19) +End: + Return(v39) + +//! > after +Parameters: v0: core::integer::u64 +blk0 (root): +Statements: + (v1: core::integer::u128) <- 1 + (v2: core::integer::u64) <- 1 +End: + Goto(blk1, {}) + +blk1: +Statements: + (v3: core::option::Option::) <- Option::Some(v2) +End: + Goto(blk3, {v3 -> v4}) + +blk2: +Statements: + (v5: ()) <- struct_construct() + (v6: core::option::Option::) <- Option::None(v5) +End: + Goto(blk3, {v6 -> v4}) + +blk3: +Statements: +End: + Match(match_enum(v4) { + Option::Some(v7) => blk4, + Option::None(v8) => blk5, + }) + +blk4: +Statements: + (v9: (core::integer::u64,)) <- struct_construct(v7) + (v10: core::panics::PanicResult::<(core::integer::u64,)>) <- PanicResult::Ok(v9) +End: + Goto(blk6, {v10 -> v11}) + +blk5: +Statements: + (v12: core::array::Array::) <- core::array::array_new::() + (v13: core::felt252) <- 29721761890975875353235833581453094220424382983267374 + (v14: core::array::Array::) <- core::array::array_append::(v12, v13) + (v15: core::panics::Panic) <- struct_construct() + (v16: (core::panics::Panic, core::array::Array::)) <- struct_construct(v15, v14) + (v17: core::panics::PanicResult::<(core::integer::u64,)>) <- PanicResult::Err(v16) +End: + Goto(blk6, {v17 -> v11}) + +blk6: +Statements: +End: + Match(match_enum(v11) { + PanicResult::Ok(v18) => blk7, + PanicResult::Err(v19) => blk13, + }) + +blk7: +Statements: + (v20: core::integer::u64) <- struct_destructure(v18) +End: + Match(match core::integer::u64_is_zero(v20) { + IsZeroResult::Zero => blk8, + IsZeroResult::NonZero(v21) => blk9, + }) + +blk8: +Statements: + (v22: core::array::Array::) <- core::array::array_new::() + (v23: core::felt252) <- 5420154128225384396790819266608 + (v24: core::array::Array::) <- core::array::array_append::(v22, v23) + (v25: core::panics::Panic) <- struct_construct() + (v26: (core::panics::Panic, core::array::Array::)) <- struct_construct(v25, v24) + (v27: core::panics::PanicResult::<(core::integer::u64,)>) <- PanicResult::Err(v26) +End: + Goto(blk10, {v27 -> v28}) + +blk9: +Statements: + (v29: core::integer::u64, v30: core::integer::u64) <- core::integer::u64_safe_divmod(v0, v21) + (v31: (core::integer::u64,)) <- struct_construct(v29) + (v32: core::panics::PanicResult::<(core::integer::u64,)>) <- PanicResult::Ok(v31) +End: + Goto(blk10, {v32 -> v28}) + +blk10: +Statements: +End: + Match(match_enum(v28) { + PanicResult::Ok(v33) => blk11, + PanicResult::Err(v34) => blk12, + }) + +blk11: +Statements: + (v37: core::panics::PanicResult::<(core::integer::u64,)>) <- PanicResult::Ok(v33) +End: + Return(v37) + +blk12: +Statements: + (v38: core::panics::PanicResult::<(core::integer::u64,)>) <- PanicResult::Err(v34) +End: + Return(v38) + +blk13: +Statements: + (v39: core::panics::PanicResult::<(core::integer::u64,)>) <- PanicResult::Err(v19) +End: + Return(v39) + +//! > lowering_diagnostics + +//! > ========================================================================== + +//! > Downcast const failure. + +//! > test_runner_name +test_match_optimizer + +//! > function +fn foo(x: u8) -> u8 { + x / 300_u16.try_into().unwrap() +} + +//! > function_name +foo + +//! > module_code + +//! > semantic_diagnostics + +//! > before +Parameters: v0: core::integer::u8 +blk0 (root): +Statements: + (v1: core::integer::u16) <- 300 +End: + Match(match core::integer::downcast::(v1) { + Option::Some(v2) => blk1, + Option::None => blk2, + }) + +blk1: +Statements: + (v3: core::option::Option::) <- Option::Some(v2) +End: + Goto(blk3, {v3 -> v4}) + +blk2: +Statements: + (v5: ()) <- struct_construct() + (v6: core::option::Option::) <- Option::None(v5) +End: + Goto(blk3, {v6 -> v4}) + +blk3: +Statements: +End: + Match(match_enum(v4) { + Option::Some(v7) => blk4, + Option::None(v8) => blk5, + }) + +blk4: +Statements: + (v9: (core::integer::u8,)) <- struct_construct(v7) + (v10: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Ok(v9) +End: + Goto(blk6, {v10 -> v11}) + +blk5: +Statements: + (v12: core::array::Array::) <- core::array::array_new::() + (v13: core::felt252) <- 29721761890975875353235833581453094220424382983267374 + (v14: core::array::Array::) <- core::array::array_append::(v12, v13) + (v15: core::panics::Panic) <- struct_construct() + (v16: (core::panics::Panic, core::array::Array::)) <- struct_construct(v15, v14) + (v17: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Err(v16) +End: + Goto(blk6, {v17 -> v11}) + +blk6: +Statements: +End: + Match(match_enum(v11) { + PanicResult::Ok(v18) => blk7, + PanicResult::Err(v19) => blk13, + }) + +blk7: +Statements: + (v20: core::integer::u8) <- struct_destructure(v18) +End: + Match(match core::integer::u8_is_zero(v20) { + IsZeroResult::Zero => blk8, + IsZeroResult::NonZero(v21) => blk9, + }) + +blk8: +Statements: + (v22: core::array::Array::) <- core::array::array_new::() + (v23: core::felt252) <- 5420154128225384396790819266608 + (v24: core::array::Array::) <- core::array::array_append::(v22, v23) + (v25: core::panics::Panic) <- struct_construct() + (v26: (core::panics::Panic, core::array::Array::)) <- struct_construct(v25, v24) + (v27: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Err(v26) +End: + Goto(blk10, {v27 -> v28}) + +blk9: +Statements: + (v29: core::integer::u8, v30: core::integer::u8) <- core::integer::u8_safe_divmod(v0, v21) + (v31: (core::integer::u8,)) <- struct_construct(v29) + (v32: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Ok(v31) +End: + Goto(blk10, {v32 -> v28}) + +blk10: +Statements: +End: + Match(match_enum(v28) { + PanicResult::Ok(v33) => blk11, + PanicResult::Err(v34) => blk12, + }) + +blk11: +Statements: + (v37: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Ok(v33) +End: + Return(v37) + +blk12: +Statements: + (v38: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Err(v34) +End: + Return(v38) + +blk13: +Statements: + (v39: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Err(v19) +End: + Return(v39) + +//! > after +Parameters: v0: core::integer::u8 +blk0 (root): +Statements: + (v1: core::integer::u16) <- 300 +End: + Goto(blk2, {}) + +blk1: +Statements: + (v3: core::option::Option::) <- Option::Some(v2) +End: + Goto(blk3, {v3 -> v4}) + +blk2: +Statements: + (v5: ()) <- struct_construct() + (v6: core::option::Option::) <- Option::None(v5) +End: + Goto(blk3, {v6 -> v4}) + +blk3: +Statements: +End: + Match(match_enum(v4) { + Option::Some(v7) => blk4, + Option::None(v8) => blk5, + }) + +blk4: +Statements: + (v9: (core::integer::u8,)) <- struct_construct(v7) + (v10: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Ok(v9) +End: + Goto(blk6, {v10 -> v11}) + +blk5: +Statements: + (v12: core::array::Array::) <- core::array::array_new::() + (v13: core::felt252) <- 29721761890975875353235833581453094220424382983267374 + (v14: core::array::Array::) <- core::array::array_append::(v12, v13) + (v15: core::panics::Panic) <- struct_construct() + (v16: (core::panics::Panic, core::array::Array::)) <- struct_construct(v15, v14) + (v17: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Err(v16) +End: + Goto(blk6, {v17 -> v11}) + +blk6: +Statements: +End: + Match(match_enum(v11) { + PanicResult::Ok(v18) => blk7, + PanicResult::Err(v19) => blk13, + }) + +blk7: +Statements: + (v20: core::integer::u8) <- struct_destructure(v18) +End: + Match(match core::integer::u8_is_zero(v20) { + IsZeroResult::Zero => blk8, + IsZeroResult::NonZero(v21) => blk9, + }) + +blk8: +Statements: + (v22: core::array::Array::) <- core::array::array_new::() + (v23: core::felt252) <- 5420154128225384396790819266608 + (v24: core::array::Array::) <- core::array::array_append::(v22, v23) + (v25: core::panics::Panic) <- struct_construct() + (v26: (core::panics::Panic, core::array::Array::)) <- struct_construct(v25, v24) + (v27: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Err(v26) +End: + Goto(blk10, {v27 -> v28}) + +blk9: +Statements: + (v29: core::integer::u8, v30: core::integer::u8) <- core::integer::u8_safe_divmod(v0, v21) + (v31: (core::integer::u8,)) <- struct_construct(v29) + (v32: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Ok(v31) +End: + Goto(blk10, {v32 -> v28}) + +blk10: +Statements: +End: + Match(match_enum(v28) { + PanicResult::Ok(v33) => blk11, + PanicResult::Err(v34) => blk12, + }) + +blk11: +Statements: + (v37: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Ok(v33) +End: + Return(v37) + +blk12: +Statements: + (v38: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Err(v34) +End: + Return(v38) + +blk13: +Statements: + (v39: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Err(v19) +End: + Return(v39) + +//! > lowering_diagnostics + +//! > ========================================================================== + //! > StorageBaseAddress const. //! > test_runner_name diff --git a/crates/cairo-lang-lowering/src/test_data/match b/crates/cairo-lang-lowering/src/test_data/match index cb9b03376c1..0761de2af46 100644 --- a/crates/cairo-lang-lowering/src/test_data/match +++ b/crates/cairo-lang-lowering/src/test_data/match @@ -1496,8 +1496,8 @@ Parameters: test_function_lowering //! > function -fn foo() -> felt252 { - match 5_u32 { +fn foo(v: u32) -> felt252 { + match v { 0 => 1, 1 => 2, 2 => 3, @@ -1519,10 +1519,9 @@ foo //! > lowering_diagnostics //! > lowering_flat -Parameters: v0: core::RangeCheck +Parameters: v0: core::RangeCheck, v1: core::integer::u32 blk0 (root): Statements: - (v1: core::integer::u32) <- 5 End: Match(match core::integer::downcast::>(v0, v1) { Option::Some(v2, v3) => blk1,