diff --git a/src/hints/mod.rs b/src/hints/mod.rs index d737e5c44..4d4c474f9 100644 --- a/src/hints/mod.rs +++ b/src/hints/mod.rs @@ -8,7 +8,6 @@ mod unimplemented; mod vars; use std::collections::{HashMap, HashSet}; -use std::ops::Add; use cairo_vm::hint_processor::builtin_hint_processor::builtin_hint_processor_definition::{ BuiltinHintProcessor, HintProcessorData, @@ -28,6 +27,7 @@ use indoc::indoc; use num_bigint::BigInt; use crate::config::DEFAULT_INPUT_PATH; +use crate::execution::helper::ExecutionHelperWrapper; use crate::io::input::StarknetOsInput; type HintImpl = fn( @@ -38,7 +38,7 @@ type HintImpl = fn( &HashMap, ) -> Result<(), HintError>; -static HINTS: [(&str, HintImpl); 51] = [ +static HINTS: [(&str, HintImpl); 52] = [ // (BREAKPOINT, breakpoint), (STARKNET_OS_INPUT, starknet_os_input), (INITIALIZE_STATE_CHANGES, initialize_state_changes), @@ -90,6 +90,7 @@ static HINTS: [(&str, HintImpl); 51] = [ (syscalls::SEND_MESSAGE_TO_L1, syscalls::send_message_to_l1), (syscalls::STORAGE_READ, syscalls::storage_read), (syscalls::STORAGE_WRITE, syscalls::storage_write), + (SET_AP_TO_ACTUAL_FEE, set_ap_to_actual_fee), (IS_ON_CURVE, is_on_curve), (IS_N_GE_TWO, is_n_ge_two), ]; @@ -204,6 +205,7 @@ pub const STARKNET_OS_INPUT: &str = indoc! {r#" ids.initial_carried_outputs.messages_to_l1 = segments.add_temp_segment() ids.initial_carried_outputs.messages_to_l2 = segments.add_temp_segment()"# }; + pub fn starknet_os_input( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -239,6 +241,7 @@ pub const INITIALIZE_STATE_CHANGES: &str = indoc! {r#" for address, contract in os_input.contracts.items() }"# }; + pub fn initialize_state_changes( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -263,6 +266,7 @@ pub fn initialize_state_changes( } pub const INITIALIZE_CLASS_HASHES: &str = "initial_dict = os_input.class_hash_to_compiled_class_hash"; + pub fn initialize_class_hashes( _vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -281,6 +285,7 @@ pub fn initialize_class_hashes( } pub const SEGMENTS_ADD: &str = "memory[ap] = to_felt_or_relocatable(segments.add())"; + pub fn segments_add( vm: &mut VirtualMachine, _exec_scopes: &mut ExecutionScopes, @@ -293,6 +298,7 @@ pub fn segments_add( } pub const SEGMENTS_ADD_TEMP: &str = "memory[ap] = to_felt_or_relocatable(segments.add_temp_segment())"; + pub fn segments_add_temp( vm: &mut VirtualMachine, _exec_scopes: &mut ExecutionScopes, @@ -304,7 +310,8 @@ pub fn segments_add_temp( insert_value_into_ap(vm, temp_segment) } -pub const TRANSACTIONS_LEN: &str = "memory[fp + 8] = to_felt_or_relocatable(len(os_input.transactions))"; +pub const TRANSACTIONS_LEN: &str = "memory[ap] = to_felt_or_relocatable(len(os_input.transactions))"; + pub fn transactions_len( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -313,10 +320,12 @@ pub fn transactions_len( _constants: &HashMap, ) -> Result<(), HintError> { let os_input = exec_scopes.get::("os_input")?; - vm.insert_value(vm.get_fp().add(8)?, os_input.transactions.len()).map_err(HintError::Memory) + + insert_value_into_ap(vm, os_input.transactions.len()) } pub const BREAKPOINT: &str = "breakpoint()"; + pub fn breakpoint( vm: &mut VirtualMachine, _exec_scopes: &mut ExecutionScopes, @@ -346,6 +355,7 @@ pub fn breakpoint( } pub const IS_N_GE_TWO: &str = "memory[ap] = to_felt_or_relocatable(ids.n >= 2)"; + pub fn is_n_ge_two( vm: &mut VirtualMachine, _exec_scopes: &mut ExecutionScopes, @@ -359,7 +369,30 @@ pub fn is_n_ge_two( Ok(()) } +pub const SET_AP_TO_ACTUAL_FEE: &str = + "memory[ap] = to_felt_or_relocatable(execution_helper.tx_execution_info.actual_fee)"; + +pub fn set_ap_to_actual_fee( + vm: &mut VirtualMachine, + exec_scopes: &mut ExecutionScopes, + _ids_data: &HashMap, + _ap_tracking: &ApTracking, + _constants: &HashMap, +) -> Result<(), HintError> { + let execution_helper = exec_scopes.get::(vars::scopes::EXECUTION_HELPER)?; + let actual_fee = execution_helper + .execution_helper + .borrow() + .tx_execution_info + .as_ref() + .ok_or(HintError::CustomHint("ExecutionHelper should have tx_execution_info".to_owned().into_boxed_str()))? + .actual_fee; + + insert_value_into_ap(vm, Felt252::from(actual_fee.0)) +} + pub const IS_ON_CURVE: &str = "ids.is_on_curve = (y * y) % SECP_P == y_square_int"; + pub fn is_on_curve( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, diff --git a/src/hints/syscalls.rs b/src/hints/syscalls.rs index bb74e8f7f..2ad1facca 100644 --- a/src/hints/syscalls.rs +++ b/src/hints/syscalls.rs @@ -13,6 +13,7 @@ use crate::execution::deprecated_syscall_handler::DeprecatedOsSyscallHandlerWrap use crate::hints::vars; pub const CALL_CONTRACT: &str = "syscall_handler.call_contract(segments=segments, syscall_ptr=ids.syscall_ptr)"; + pub fn call_contract( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -29,6 +30,7 @@ pub fn call_contract( } pub const DELEGATE_CALL: &str = "syscall_handler.delegate_call(segments=segments, syscall_ptr=ids.syscall_ptr)"; + pub fn delegate_call( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -46,6 +48,7 @@ pub fn delegate_call( pub const DELEGATE_L1_HANDLER: &str = "syscall_handler.delegate_l1_handler(segments=segments, syscall_ptr=ids.syscall_ptr)"; + pub fn delegate_l1_handler( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -62,6 +65,7 @@ pub fn delegate_l1_handler( } pub const DEPLOY: &str = "syscall_handler.deploy(segments=segments, syscall_ptr=ids.syscall_ptr)"; + pub fn deploy( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -78,6 +82,7 @@ pub fn deploy( } pub const EMIT_EVENT: &str = "syscall_handler.emit_event(segments=segments, syscall_ptr=ids.syscall_ptr)"; + pub fn emit_event( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -94,6 +99,7 @@ pub fn emit_event( } pub const GET_BLOCK_NUMBER: &str = "syscall_handler.get_block_number(segments=segments, syscall_ptr=ids.syscall_ptr)"; + pub fn get_block_number( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -111,6 +117,7 @@ pub fn get_block_number( pub const GET_BLOCK_TIMESTAMP: &str = "syscall_handler.get_block_timestamp(segments=segments, syscall_ptr=ids.syscall_ptr)"; + pub fn get_block_timestamp( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -128,6 +135,7 @@ pub fn get_block_timestamp( pub const GET_CALLER_ADDRESS: &str = "syscall_handler.get_caller_address(segments=segments, syscall_ptr=ids.syscall_ptr)"; + pub fn get_caller_address( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -145,6 +153,7 @@ pub fn get_caller_address( pub const GET_CONTRACT_ADDRESS: &str = "syscall_handler.get_contract_address(segments=segments, syscall_ptr=ids.syscall_ptr)"; + pub fn get_contract_address( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -162,6 +171,7 @@ pub fn get_contract_address( pub const GET_SEQUENCER_ADDRESS: &str = "syscall_handler.get_sequencer_address(segments=segments, syscall_ptr=ids.syscall_ptr)"; + pub fn get_sequencer_address( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -178,6 +188,7 @@ pub fn get_sequencer_address( } pub const GET_TX_INFO: &str = "syscall_handler.get_tx_info(segments=segments, syscall_ptr=ids.syscall_ptr)"; + pub fn get_tx_info( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -194,6 +205,7 @@ pub fn get_tx_info( } pub const GET_TX_SIGNATURE: &str = "syscall_handler.get_tx_signature(segments=segments, syscall_ptr=ids.syscall_ptr)"; + pub fn get_tx_signature( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -210,6 +222,7 @@ pub fn get_tx_signature( } pub const LIBRARY: &str = "syscall_handler.library_call(segments=segments, syscall_ptr=ids.syscall_ptr)"; + pub fn library_call( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -227,6 +240,7 @@ pub fn library_call( pub const LIBRARY_CALL_L1_HANDLER: &str = "syscall_handler.library_call_l1_handler(segments=segments, syscall_ptr=ids.syscall_ptr)"; + pub fn library_call_l1_handler( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -243,6 +257,7 @@ pub fn library_call_l1_handler( } pub const REPLACE_CLASS: &str = "syscall_handler.replace_class(segments=segments, syscall_ptr=ids.syscall_ptr)"; + pub fn replace_class( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -260,6 +275,7 @@ pub fn replace_class( pub const SEND_MESSAGE_TO_L1: &str = "syscall_handler.send_message_to_l1(segments=segments, syscall_ptr=ids.syscall_ptr)"; + pub fn send_message_to_l1( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -276,6 +292,7 @@ pub fn send_message_to_l1( } pub const STORAGE_READ: &str = "syscall_handler.storage_read(segments=segments, syscall_ptr=ids.syscall_ptr)"; + pub fn storage_read( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -292,6 +309,7 @@ pub fn storage_read( } pub const STORAGE_WRITE: &str = "syscall_handler.storage_write(segments=segments, syscall_ptr=ids.syscall_ptr)"; + pub fn storage_write( vm: &mut VirtualMachine, exec_scopes: &mut ExecutionScopes, @@ -336,39 +354,14 @@ pub fn set_syscall_ptr( #[cfg(test)] mod tests { - use std::sync::Arc; - - use blockifier::block_context::{BlockContext, FeeTokenAddresses, GasPrices}; - use cairo_vm::types::exec_scope::ExecutionScopes; + use blockifier::block_context::BlockContext; use cairo_vm::types::relocatable::Relocatable; use rstest::{fixture, rstest}; - use starknet_api::block::{BlockNumber, BlockTimestamp}; - use starknet_api::core::{ChainId, ContractAddress, PatriciaKey}; - use starknet_api::hash::StarkHash; - use starknet_api::{contract_address, patricia_key}; use super::*; + use crate::hints::tests::tests::block_context; use crate::ExecutionHelperWrapper; - #[fixture] - fn block_context() -> BlockContext { - BlockContext { - chain_id: ChainId("SN_GOERLI".to_string()), - block_number: BlockNumber(1_000_000), - block_timestamp: BlockTimestamp(1_704_067_200), - sequencer_address: contract_address!("0x0"), - fee_token_addresses: FeeTokenAddresses { - eth_fee_token_address: contract_address!("0x1"), - strk_fee_token_address: contract_address!("0x2"), - }, - vm_resource_fee_cost: Arc::new(HashMap::new()), - gas_prices: GasPrices { eth_l1_gas_price: 1, strk_l1_gas_price: 1 }, - invoke_tx_max_n_steps: 1, - validate_max_n_steps: 1, - max_recursion_depth: 50, - } - } - #[fixture] fn exec_scopes(block_context: BlockContext) -> ExecutionScopes { let syscall_ptr = Relocatable::from((0, 0)); diff --git a/src/hints/tests.rs b/src/hints/tests.rs index 8d811d58d..132db0a5e 100644 --- a/src/hints/tests.rs +++ b/src/hints/tests.rs @@ -1,9 +1,15 @@ #[cfg(test)] -mod tests { - use cairo_vm::serde::deserialize_program::ApTracking; - use cairo_vm::types::exec_scope::ExecutionScopes; - use num_bigint::BigInt; - use rstest::rstest; +pub(crate) mod tests { + use std::sync::Arc; + + use blockifier::block_context::{BlockContext, FeeTokenAddresses, GasPrices}; + use blockifier::transaction::objects::TransactionExecutionInfo; + use rstest::{fixture, rstest}; + use starknet_api::block::{BlockNumber, BlockTimestamp}; + use starknet_api::core::{ChainId, ContractAddress, PatriciaKey}; + use starknet_api::hash::StarkHash; + use starknet_api::transaction::Fee; + use starknet_api::{contract_address, patricia_key}; use crate::hints::*; @@ -31,6 +37,67 @@ mod tests { }; } + #[fixture] + pub fn block_context() -> BlockContext { + BlockContext { + chain_id: ChainId("SN_GOERLI".to_string()), + block_number: BlockNumber(1_000_000), + block_timestamp: BlockTimestamp(1_704_067_200), + sequencer_address: contract_address!("0x0"), + fee_token_addresses: FeeTokenAddresses { + eth_fee_token_address: contract_address!("0x1"), + strk_fee_token_address: contract_address!("0x2"), + }, + vm_resource_fee_cost: Arc::new(HashMap::new()), + gas_prices: GasPrices { eth_l1_gas_price: 1, strk_l1_gas_price: 1 }, + invoke_tx_max_n_steps: 1, + validate_max_n_steps: 1, + max_recursion_depth: 50, + } + } + + #[fixture] + fn transaction_execution_info() -> TransactionExecutionInfo { + TransactionExecutionInfo { + validate_call_info: None, + execute_call_info: None, + fee_transfer_call_info: None, + actual_fee: Fee(1234), + actual_resources: Default::default(), + revert_error: None, + } + } + + #[rstest] + fn test_set_ap_to_actual_fee_hint( + block_context: BlockContext, + transaction_execution_info: TransactionExecutionInfo, + ) { + let mut vm = VirtualMachine::new(false); + vm.set_fp(1); + vm.add_memory_segment(); + vm.add_memory_segment(); + + let ids_data = Default::default(); + let ap_tracking = ApTracking::default(); + + let mut exec_scopes = ExecutionScopes::new(); + + // inject txn execution info with a fee for hint to use + let execution_infos = vec![transaction_execution_info]; + let exec_helper = ExecutionHelperWrapper::new(execution_infos, &block_context); + exec_helper.start_tx(None); + exec_scopes.insert_box(vars::scopes::EXECUTION_HELPER, Box::new(exec_helper)); + + set_ap_to_actual_fee(&mut vm, &mut exec_scopes, &ids_data, &ap_tracking, &Default::default()) + .expect("set_ap_to_actual_fee() failed"); + + let ap = vm.get_ap(); + + let fee = vm.get_integer(ap).unwrap().into_owned(); + assert_eq!(fee, 1234.into()); + } + #[test] fn test_is_on_curve() { let mut vm = VirtualMachine::new(false); diff --git a/src/hints/unimplemented.rs b/src/hints/unimplemented.rs index 07d2741e5..841180b1c 100644 --- a/src/hints/unimplemented.rs +++ b/src/hints/unimplemented.rs @@ -94,9 +94,6 @@ const SET_TREE_STRUCTURE: &str = indoc! {r#" ])"# }; -#[allow(unused)] -const SET_AP_TO_ACTUAL_FEE: &str = "memory[ap] = to_felt_or_relocatable(execution_helper.tx_execution_info.actual_fee)"; - #[allow(unused)] const SPLIT_OUTPUT1: &str = indoc! {r#" tmp, ids.output1_low = divmod(ids.output1, 256 ** 7) diff --git a/src/hints/vars.rs b/src/hints/vars.rs index 0fcebda84..469610737 100644 --- a/src/hints/vars.rs +++ b/src/hints/vars.rs @@ -1,4 +1,5 @@ pub mod scopes { + pub const EXECUTION_HELPER: &str = "execution_helper"; pub const SYSCALL_HANDLER: &str = "syscall_handler"; } diff --git a/tests/common/utils.rs b/tests/common/utils.rs index 8b6a4fd27..1d587a9a4 100644 --- a/tests/common/utils.rs +++ b/tests/common/utils.rs @@ -1,8 +1,7 @@ -use std::{env, fs, path}; +use std::{env, path}; use blockifier::execution::contract_class::{ContractClass, ContractClassV0}; use cairo_vm::vm::errors::cairo_run_errors::CairoRunError; -use cairo_vm::vm::vm_core::VirtualMachine; use starknet_api::core::ContractAddress; use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContractClass;