Skip to content
This repository has been archived by the owner on Dec 15, 2023. It is now read-only.

Fix estimate fee #88

Merged
merged 3 commits into from
Apr 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 3 additions & 17 deletions starknet_devnet/starknet_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,14 @@
from starkware.starkware_utils.error_handling import StarkException
from starkware.starknet.services.api.feeder_gateway.block_hash import calculate_block_hash
from starkware.starknet.business_logic.transaction_fee import calculate_tx_fee_by_cairo_usage
from starkware.starknet.services.api.contract_definition import EntryPointType
from starkware.starknet.definitions import constants

from .origin import NullOrigin, Origin
from .general_config import DEFAULT_GENERAL_CONFIG
from .util import (
Choice, StarknetDevnetException, TxStatus, DummyExecutionInfo,
fixed_length_hex, enable_pickling, generate_state_update
)
from .contract_wrapper import ContractWrapper
from .contract_wrapper import ContractWrapper, call_internal_tx
from .transaction_wrapper import TransactionWrapper, DeployTransactionWrapper, InvokeTransactionWrapper
from .postman_wrapper import LocalPostmanWrapper
from .constants import FAILURE_REASON_KEY
Expand Down Expand Up @@ -555,21 +553,9 @@ def get_state_update(self, block_hash=None, block_number=None):
async def calculate_actual_fee(self, external_tx: InvokeFunction):
"""Calculates actual fee"""
state = await self.__get_state()
internal_tx = InternalInvokeFunction.create(
contract_address=external_tx.contract_address,
entry_point_selector=external_tx.entry_point_selector,
max_fee=external_tx.max_fee,
entry_point_type=EntryPointType.EXTERNAL,
calldata=external_tx.calldata,
signature=external_tx.signature,
nonce=None,
chain_id=state.general_config.chain_id.value,
# Need to set to 0 as it will be invoked in apply_state_updates
version=constants.TRANSACTION_VERSION,
)
internal_tx = InternalInvokeFunction.from_external_query_tx(external_tx, state.general_config)

state_copy = state.state._copy() # pylint: disable=protected-access
execution_info = await internal_tx.apply_state_updates(state_copy, state.general_config)
execution_info = await call_internal_tx(state.copy(), internal_tx)

actual_fee = calculate_tx_fee_by_cairo_usage(
general_config=state.general_config,
Expand Down
22 changes: 15 additions & 7 deletions test/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from starkware.cairo.common.hash_state import compute_hash_on_elements
from starkware.crypto.signature.signature import private_to_stark_key, sign
from starkware.starknet.public.abi import get_selector_from_name
from starkware.starknet.definitions.constants import TRANSACTION_VERSION, QUERY_VERSION

from .util import deploy, call, invoke, estimate_fee

Expand All @@ -16,8 +17,6 @@
ACCOUNT_PATH = f"{ACCOUNT_ARTIFACTS_PATH}/{ACCOUNT_AUTHOR}/{ACCOUNT_VERSION}/Account.cairo/Account.json"
ACCOUNT_ABI_PATH = f"{ACCOUNT_ARTIFACTS_PATH}/{ACCOUNT_AUTHOR}/{ACCOUNT_VERSION}/Account.cairo/Account_abi.json"

TRANSACTION_VERSION = 0

PRIVATE_KEY = 123456789987654321
PUBLIC_KEY = private_to_stark_key(PRIVATE_KEY)

Expand All @@ -43,7 +42,7 @@ def str_to_felt(text):
"""Converts string to felt."""
return int.from_bytes(bytes(text, "ascii"), "big")

def hash_multicall(sender, calls, nonce, max_fee):
def hash_multicall(sender, calls, nonce, max_fee, version):
"""desc"""
hash_array = []

Expand All @@ -57,7 +56,7 @@ def hash_multicall(sender, calls, nonce, max_fee):
compute_hash_on_elements(hash_array),
nonce,
max_fee,
TRANSACTION_VERSION
version
])


Expand Down Expand Up @@ -89,7 +88,7 @@ def adapt_inputs(execute_calldata):
"""Get stringified inputs from execute_calldata."""
return [str(v) for v in execute_calldata]

def get_execute_args(calls, account_address, nonce=None, max_fee=0):
def get_execute_args(calls, account_address, nonce=None, max_fee=0, version=TRANSACTION_VERSION):
"""Returns signature and execute calldata"""

if nonce is None:
Expand All @@ -99,7 +98,11 @@ def get_execute_args(calls, account_address, nonce=None, max_fee=0):
calls_with_selector = [
(call[0], get_selector_from_name(call[1]), call[2]) for call in calls]
message_hash = hash_multicall(
int(account_address, 16), calls_with_selector, int(nonce), max_fee
sender=int(account_address, 16),
calls=calls_with_selector,
nonce=int(nonce),
max_fee=max_fee,
version=version
)
signature = get_signature(message_hash)

Expand All @@ -111,7 +114,12 @@ def get_execute_args(calls, account_address, nonce=None, max_fee=0):

def get_estimated_fee(calls, account_address, nonce=None):
"""Get estmated fee."""
signature, execute_calldata = get_execute_args(calls, account_address, nonce)
signature, execute_calldata = get_execute_args(
calls=calls,
account_address=account_address,
nonce=nonce,
version=QUERY_VERSION
)

return estimate_fee(
"__execute__",
Expand Down