Skip to content
This repository has been archived by the owner on Dec 15, 2023. It is now read-only.

Bounce back invalid transactions #518

Merged
merged 9 commits into from
Aug 11, 2023
Merged
9 changes: 4 additions & 5 deletions starknet_devnet/blueprints/rpc/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Callable, Dict, List, Tuple, Union

from flask import Blueprint, request
from starkware.starkware_utils.error_handling import StarkException

from starknet_devnet.blueprints.rpc.blocks import (
block_hash_and_number,
Expand All @@ -32,9 +33,9 @@
from starknet_devnet.blueprints.rpc.state import get_state_update
from starknet_devnet.blueprints.rpc.storage import get_storage_at
from starknet_devnet.blueprints.rpc.structures.types import (
GATEWAY_TO_RPC_ERROR,
PredefinedRpcErrorCode,
RpcError,
map_gateway_to_rpc_error_dict,
)
from starknet_devnet.blueprints.rpc.transactions import (
add_declare_transaction,
Expand All @@ -48,7 +49,6 @@
simulate_transaction,
)
from starknet_devnet.blueprints.rpc.utils import rpc_error, rpc_response
from starknet_devnet.util import StarknetDevnetException

methods = {
"getBlockWithTxHashes": get_block_with_tx_hashes,
Expand Down Expand Up @@ -110,9 +110,8 @@ async def base_route():
code=PredefinedRpcErrorCode.INTERNAL_ERROR.value,
message=str(error),
)
except StarknetDevnetException as ex:
default_error = PredefinedRpcErrorCode.INTERNAL_ERROR
rpc_error_dict = GATEWAY_TO_RPC_ERROR.get(ex.code, default_error)
except StarkException as ex:
rpc_error_dict = map_gateway_to_rpc_error_dict(exception=ex)
return rpc_error(message_id=message_id, **rpc_error_dict)

return rpc_response(message_id=message_id, content=result)
Expand Down
35 changes: 32 additions & 3 deletions starknet_devnet/blueprints/rpc/structures/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from starkware.starknet.definitions.error_codes import StarknetErrorCode
from starkware.starknet.definitions.transaction_type import TransactionType
from starkware.starknet.services.api.feeder_gateway.response_objects import BlockStatus
from starkware.starkware_utils.error_handling import StarkException
from typing_extensions import Literal, TypedDict

from ..rpc_spec import RPC_SPECIFICATION
Expand Down Expand Up @@ -125,6 +126,34 @@ def _combine_rpc_errors():

RPC_ERRORS = _combine_rpc_errors()

GATEWAY_TO_RPC_ERROR = {
StarknetErrorCode.BLOCK_NOT_FOUND: RPC_ERRORS["BLOCK_NOT_FOUND"],
}

def map_gateway_to_rpc_error_dict(exception: StarkException) -> RpcError:
"""
JSON-RPC cannot work with raw StarkExceptions, they have to be properly mapped
Contains errors from rpc spec 0.4.0 - necessary for proper mapping of errors
from Starknet 0.12.1 (mostly validation related). Those custom definitions be
removed if RPC 0.4.0 support is added - they be supported via the spec file.
"""
return {
StarknetErrorCode.BLOCK_NOT_FOUND: RPC_ERRORS["BLOCK_NOT_FOUND"],
StarknetErrorCode.INSUFFICIENT_MAX_FEE: {
"code": 53,
"message": "Max fee is smaller than the minimal transaction cost (validation plus fee transfer)",
},
StarknetErrorCode.INSUFFICIENT_ACCOUNT_BALANCE: {
"code": 54,
"message": "Account balance is smaller than the transaction's max_fee",
},
StarknetErrorCode.VALIDATE_FAILURE: {
"code": 55,
"message": "Account validation failed",
},
StarknetErrorCode.INVALID_TRANSACTION_NONCE: {
"code": 52,
"message": "Invalid transaction nonce",
},
StarknetErrorCode.UNDECLARED_CLASS: RPC_ERRORS["CLASS_HASH_NOT_FOUND"],
}.get(exception.code) or {
"code": PredefinedRpcErrorCode.INTERNAL_ERROR.value,
"message": f"Internal error occurred: {exception}",
}
160 changes: 135 additions & 25 deletions starknet_devnet/starknet_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,34 @@
This module introduces `StarknetWrapper`, a wrapper class of
starkware.starknet.testing.starknet.Starknet.
"""
import asyncio
import pprint
from copy import deepcopy
from types import TracebackType
from typing import Dict, List, Optional, Set, Tuple, Type, Union

import cloudpickle as pickle
from starkware.starknet.business_logic.state.state import BlockInfo, CachedState
from starkware.starknet.business_logic.execution.objects import (
ExecutionResourcesManager,
ResourcesMapping,
)
from starkware.starknet.business_logic.state.state import (
BlockInfo,
CachedState,
StateSyncifier,
UpdatesTrackerState,
)
from starkware.starknet.business_logic.state.storage_domain import StorageDomain
from starkware.starknet.business_logic.transaction.fee import calculate_tx_fee
from starkware.starknet.business_logic.transaction.objects import (
CallInfo,
InternalDeclare,
InternalAccountTransaction,
InternalDeploy,
InternalDeployAccount,
InternalInvokeFunction,
InternalL1Handler,
InternalTransaction,
TransactionExecutionInfo,
)
from starkware.starknet.business_logic.utils import calculate_tx_resources
from starkware.starknet.core.os.contract_address.contract_address import (
calculate_contract_address_from_hash,
)
Expand Down Expand Up @@ -55,6 +64,7 @@
TransactionTrace,
)
from starkware.starknet.services.api.gateway.transaction import (
AccountTransaction,
Declare,
DeployAccount,
DeprecatedDeclare,
Expand All @@ -69,7 +79,11 @@
from starkware.starknet.third_party.open_zeppelin.starknet_contracts import (
account_contract as oz_account_class,
)
from starkware.starkware_utils.error_handling import StarkErrorCode, StarkException
from starkware.starkware_utils.error_handling import (
StarkErrorCode,
StarkException,
stark_assert_le,
)

from .accounts import Accounts
from .block_info_generator import BlockInfoGenerator
Expand Down Expand Up @@ -110,6 +124,7 @@
get_storage_diffs,
group_classes_by_version,
logger,
stark_assert_call_succeeded,
warn,
)

Expand Down Expand Up @@ -352,9 +367,6 @@ async def declare(

state = self.get_state()
async with self.__get_transaction_handler(external_tx) as tx_handler:
tx_handler.internal_tx = InternalDeclare.from_external(
external_tx, state.general_config
)
# extract class hash here if execution later fails
class_hash = tx_handler.internal_tx.class_hash

Expand Down Expand Up @@ -413,11 +425,13 @@ def _update_block_number(self):
)
return block_info

def __get_transaction_handler(self, external_tx=None):
def __get_transaction_handler( # pylint: disable=too-many-statements
self, external_tx: Optional[AccountTransaction] = None
):
class TransactionHandler:
"""Class for with-blocks in transactions"""

internal_tx: InternalTransaction
internal_tx: Optional[InternalAccountTransaction] = None
execution_info: TransactionExecutionInfo = TransactionExecutionInfo.empty()
internal_calls: List[CallInfo] = []
deployed_contracts: List[ContractAddressHashPair] = []
Expand All @@ -428,24 +442,128 @@ class TransactionHandler:
def __init__(self, starknet_wrapper: StarknetWrapper):
self.starknet_wrapper = starknet_wrapper
self.preserved_block_info = starknet_wrapper._update_block_number()
if external_tx:
self._validate_fee(external_tx)
self.internal_tx = InternalAccountTransaction.from_external(
external_tx, starknet_wrapper.get_state().general_config
)

def _check_tx_fee(self, transaction):
if not hasattr(transaction, "max_fee"):
return
def _check_nonce(self, state: UpdatesTrackerState):
nonce = state.get_nonce_at(
storage_domain=StorageDomain.ON_CHAIN,
contract_address=self.internal_tx.sender_address,
)
# BACKWARD-COMPATIBILITY.
tx_nonce = (
0 if self.internal_tx.nonce is None else self.internal_tx.nonce
)
stark_assert_le(
nonce,
tx_nonce,
code=StarknetErrorCode.INVALID_TRANSACTION_NONCE,
message="Transaction's nonce must be greater than or equal to the last known nonce.",
)

def _check_balance(self, state: UpdatesTrackerState):
balance = state.get_fee_token_balance(
storage_domain=StorageDomain.ON_CHAIN,
contract_address=self.internal_tx.sender_address,
fee_token_address=self.starknet_wrapper.fee_token.address,
)
stark_assert_le(
self.internal_tx.max_fee,
balance,
code=StarknetErrorCode.INSUFFICIENT_ACCOUNT_BALANCE,
message="Account balance must be greater or equal to the transaction's max_fee.",
)

def _validate(self, state: UpdatesTrackerState) -> ResourcesMapping:
if isinstance(self.internal_tx, InternalDeployAccount):
# Run the entire transaction since a constructor call must precede the `validate`.
tx_execution_info = self.internal_tx.apply_concurrent_changes(
state=state,
general_config=self.starknet_wrapper.get_state().general_config,
)
for call_info in tx_execution_info.non_optional_calls:
stark_assert_call_succeeded(call_info=call_info)

actual_resources = tx_execution_info.actual_resources
else:
resources_manager = ExecutionResourcesManager.empty()
validate_info, _ = self.internal_tx.run_validate_entrypoint(
state=state,
general_config=self.starknet_wrapper.get_state().general_config,
resources_manager=resources_manager,
remaining_gas=self.internal_tx.get_initial_gas(),
)
# Check can be removed when v0 transactions are disabled
if self.internal_tx.version > 0:
assert (
validate_info is not None
), "validate_info must be not None for version > 0."
stark_assert_call_succeeded(call_info=validate_info)

actual_resources = calculate_tx_resources(
state=state,
resources_manager=resources_manager,
call_infos=[validate_info],
tx_type=self.internal_tx.tx_type,
fee_token_address=self.starknet_wrapper.fee_token.address,
is_nonce_increment=self.internal_tx.version > 0,
sender_address=self.internal_tx.sender_address,
)

return actual_resources

def _check_validation_fee(
self, state: UpdatesTrackerState, actual_resources: ResourcesMapping
):
# Check that max_fee is high enough to pay for the validation.
actual_fee = calculate_tx_fee(
gas_price=state.block_info.gas_price,
general_config=self.starknet_wrapper.get_state().general_config,
resources=actual_resources,
)

stark_assert_le(
actual_fee,
self.internal_tx.max_fee,
code=StarknetErrorCode.INSUFFICIENT_MAX_FEE,
message="Max fee must be greater or equal to the validation's actual fee.",
)

def _validate_fee(self, external_tx: AccountTransaction):
if (
transaction.version != LEGACY_TX_VERSION
and transaction.max_fee == 0
external_tx.version != LEGACY_TX_VERSION
and external_tx.max_fee == 0
and not self.starknet_wrapper.config.allow_max_fee_zero
):
raise StarknetDevnetException(
code=StarknetErrorCode.OUT_OF_RANGE_FEE,
message="max_fee == 0 is not supported.",
message="max_fee must be bigger than 0.",
)

async def __aenter__(self):
self._check_tx_fee(external_tx)
if self.internal_tx:
state = self.starknet_wrapper.get_state().state._copy()
loop = asyncio.get_running_loop()
state = UpdatesTrackerState(
state=StateSyncifier(async_state=state, loop=loop)
)
await asyncio.to_thread(
self._inner_perform_state_related_validations, state=state
)
return self

def _inner_perform_state_related_validations(
self, state: UpdatesTrackerState
):
self._check_nonce(state)
self._check_balance(state)
validation_resources = self._validate(state)
if self.internal_tx.max_fee:
self._check_validation_fee(state, validation_resources)

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
Expand Down Expand Up @@ -535,7 +653,6 @@ async def __aexit__(
async def deploy_account(self, external_tx: DeployAccount):
"""Deploys account and returns (address, tx_hash)"""

state = self.get_state()
account_address = calculate_contract_address_from_hash(
salt=external_tx.contract_address_salt,
class_hash=external_tx.class_hash,
Expand All @@ -546,10 +663,6 @@ async def deploy_account(self, external_tx: DeployAccount):
async with self.__get_transaction_handler(
external_tx=external_tx
) as tx_handler:
tx_handler.internal_tx = InternalDeployAccount.from_external(
external_tx, state.general_config
)

tx_handler.execution_info = await self.__deploy(tx_handler.internal_tx)
tx_handler.internal_calls = (
tx_handler.execution_info.call_info.internal_calls
Expand All @@ -566,9 +679,6 @@ async def invoke(self, external_tx: InvokeFunction):
async with self.__get_transaction_handler(
external_tx=external_tx
) as tx_handler:
tx_handler.internal_tx = InternalInvokeFunction.from_external(
external_tx, state.general_config
)
tx_handler.execution_info = await state.execute_tx(tx_handler.internal_tx)
tx_handler.internal_calls = (
tx_handler.execution_info.call_info.internal_calls
Expand Down
28 changes: 27 additions & 1 deletion starknet_devnet/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,24 @@
from typing import Dict, List, Set, Tuple

from flask import request
from services.everest.definitions.fields import format_felt_list
from starkware.starknet.business_logic.state.state import CachedState
from starkware.starknet.business_logic.state.storage_domain import StorageDomain
from starkware.starknet.business_logic.transaction.objects import CallInfo
from starkware.starknet.definitions.error_codes import StarknetErrorCode
from starkware.starknet.public.abi import SELECTOR_TO_NAME
from starkware.starknet.services.api.feeder_gateway.response_objects import (
ClassHashPair,
ContractAddressHashPair,
FeeEstimationInfo,
StorageEntry,
)
from starkware.starknet.testing.contract import StarknetContract
from starkware.starkware_utils.error_handling import StarkErrorCode, StarkException
from starkware.starkware_utils.error_handling import (
StarkErrorCode,
StarkException,
stark_assert,
)


def parse_hex_string(arg: str) -> int:
Expand Down Expand Up @@ -343,3 +350,22 @@ async def wrapper(*args, **kwargs):
return wrapper

return decorator


def stark_assert_call_succeeded(call_info: CallInfo):
"""Assert the call that produced the provided `call_info` was successful; fail otherwise"""
assert (
call_info.entry_point_selector is not None
), "An entry point selector must be specified."
entry_point_name = SELECTOR_TO_NAME.get(call_info.entry_point_selector)
assert (
entry_point_name is not None
), f"{call_info.entry_point_selector} isn't defined."
stark_assert(
call_info.result().succeeded,
code=StarknetErrorCode.VALIDATE_FAILURE,
message=(
f"{entry_point_name} call failed; failure reason: "
f"{format_felt_list(call_info.retdata)}."
),
)
Loading