Skip to content
This repository has been archived by the owner on Dec 15, 2023. It is now read-only.

Enable using default_entrypoint by switching to invoke_raw #78

Merged
merged 11 commits into from
Apr 13, 2022
123 changes: 0 additions & 123 deletions starknet_devnet/adapt.py

This file was deleted.

57 changes: 24 additions & 33 deletions starknet_devnet/contract_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,11 @@
from dataclasses import dataclass
from typing import List

from starkware.starknet.public.abi import get_selector_from_name
from starkware.starknet.services.api.contract_definition import ContractDefinition
from starkware.starknet.testing.contract import StarknetContract
from starkware.starknet.testing.objects import StarknetTransactionExecutionInfo
from starkware.starknet.utils.api_utils import cast_to_felts

from starknet_devnet.adapt import adapt_calldata, adapt_output
from starknet_devnet.util import Choice, StarknetDevnetException

def extract_types(abi):
"""
Extracts the types (structs) used in the contract whose ABI is provided.
"""

structs = [entry for entry in abi if entry["type"] == "struct"]
type_dict = { struct["name"]: struct for struct in structs }
return type_dict
from starknet_devnet.util import Choice

@dataclass
class ContractWrapper:
Expand All @@ -36,29 +25,31 @@ def __init__(self, contract: StarknetContract, contract_definition: ContractDefi
"bytecode": self.contract_definition["program"]["data"]
}

self.types: dict = extract_types(contract_definition.abi)

async def call_or_invoke(self, choice: Choice, entry_point_selector: int, calldata: List[int], signature: List[int]):
# pylint: disable=too-many-arguments
async def call_or_invoke(
self,
choice: Choice,
entry_point_selector: int,
calldata: List[int],
signature: List[int],
caller_address: int,
max_fee: int
):
"""
Depending on `choice`, performs the call or invoke of the function
identified with `entry_point_selector`, potentially passing in `calldata` and `signature`.
"""
function_mapping = self.contract._abi_function_mapping # pylint: disable=protected-access
for method_name in function_mapping:
selector = get_selector_from_name(method_name)
if selector == entry_point_selector:
try:
method = getattr(self.contract, method_name)
except NotImplementedError as nie:
raise StarknetDevnetException from nie
function_abi = function_mapping[method_name]
break
else:
raise StarknetDevnetException(message=f"Illegal method selector: {entry_point_selector}.")

adapted_calldata = adapt_calldata(calldata, function_abi["inputs"], self.types)
state = self.contract.state.copy() if choice == Choice.CALL else self.contract.state

execution_info = await state.invoke_raw(
contract_address=self.contract.contract_address,
selector=entry_point_selector,
calldata=calldata,
caller_address=caller_address,
max_fee=max_fee,
signature=None if signature is None else cast_to_felts(values=signature),
badurinantun marked this conversation as resolved.
Show resolved Hide resolved
)

prepared = method(*adapted_calldata)
called = getattr(prepared, choice.value)
execution_info: StarknetTransactionExecutionInfo = await called(signature=signature)
return adapt_output(execution_info.result), execution_info
result = list(map(hex, execution_info.call_info.retdata))
return result, execution_info
8 changes: 6 additions & 2 deletions starknet_devnet/starknet_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,9 @@ async def invoke(self, transaction: InvokeFunction):
Choice.INVOKE,
entry_point_selector=invoke_transaction.entry_point_selector,
calldata=invoke_transaction.calldata,
signature=invoke_transaction.signature
signature=invoke_transaction.signature,
caller_address=invoke_transaction.caller_address,
max_fee=invoke_transaction.max_fee
)
status = TxStatus.ACCEPTED_ON_L2
error_message = None
Expand Down Expand Up @@ -233,7 +235,9 @@ async def call(self, transaction: InvokeFunction):
Choice.CALL,
entry_point_selector=transaction.entry_point_selector,
calldata=transaction.calldata,
signature=transaction.signature
signature=transaction.signature,
caller_address=0,
max_fee=transaction.max_fee
)

return { "result": adapted_result }
Expand Down
54 changes: 33 additions & 21 deletions starknet_devnet/transaction_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
from typing import List

from starkware.starknet.business_logic.internal_transaction import InternalInvokeFunction
from starkware.starknet.business_logic.execution.objects import Event, L2ToL1MessageInfo
from starkware.starknet.services.api.feeder_gateway.response_objects import FunctionInvocation
from starkware.starknet.services.api.gateway.transaction import Deploy
from starkware.starknet.definitions.error_codes import StarknetErrorCode
from starkware.starknet.definitions.transaction_type import TransactionType
from starkware.starknet.testing.objects import StarknetTransactionExecutionInfo
from starkware.starknet.testing.objects import StarknetTransactionExecutionInfo, TransactionExecutionInfo

from .util import TxStatus, fixed_length_hex
from .constants import FAILURE_REASON_KEY
Expand Down Expand Up @@ -43,48 +45,51 @@ class InvokeTransactionDetails(TransactionDetails):
entry_point_selector: str
entry_point_type: str

def get_events(execution_info: StarknetTransactionExecutionInfo):
"""Extract events if any; stringify content."""
if not hasattr(execution_info, "raw_events"):
return []
events = []
for event in execution_info.raw_events:
events.append({
def process_events(events: List[Event]):
"""Extract events and hex the content."""

processed_events = []
for event in events:
processed_events.append({
"from_address": hex(event.from_address),
"data": [hex(d) for d in event.data],
"keys": [hex(key) for key in event.keys]
})
return events
return processed_events

class TransactionWrapper(ABC):
"""Transaction Wrapper base class."""

# pylint: disable=too-many-arguments
@abstractmethod
def __init__(
self, status: TxStatus, execution_info: StarknetTransactionExecutionInfo, tx_details: TransactionDetails
self,
status: TxStatus,
call_info: FunctionInvocation,
badurinantun marked this conversation as resolved.
Show resolved Hide resolved
tx_details: TransactionDetails,
events: List[Event],
l2_to_l1_messages: List[L2ToL1MessageInfo]
):
self.transaction_hash = tx_details.transaction_hash

events = get_events(execution_info)

self.transaction = {
"status": status.name,
"transaction": tx_details.to_dict(),
"transaction_index": 0 # always the first (and only) tx in the block
}

self.receipt = {
"execution_resources": execution_info.call_info.execution_resources,
"l2_to_l1_messages": execution_info.l2_to_l1_messages,
"events": events,
"execution_resources": call_info.execution_resources,
"l2_to_l1_messages": l2_to_l1_messages,
"events": process_events(events),
"status": status.name,
"transaction_hash": tx_details.transaction_hash,
"transaction_index": 0 # always the first (and only) tx in the block
}

if status is not TxStatus.REJECTED:
self.trace = {
"function_invocation": execution_info.call_info.dump(),
"function_invocation": call_info.dump(),
"signature": tx_details.to_dict().get("signature", [])
}

Expand Down Expand Up @@ -130,25 +135,30 @@ def __init__(
):
super().__init__(
status,
execution_info,
execution_info.call_info,
DeployTransactionDetails(
TransactionType.DEPLOY.name,
contract_address=fixed_length_hex(contract_address),
transaction_hash=fixed_length_hex(tx_hash),
constructor_calldata=[hex(arg) for arg in transaction.constructor_calldata],
contract_address_salt=hex(transaction.contract_address_salt),
class_hash=fixed_length_hex(int.from_bytes(contract_hash, "big"))
)
),
events=execution_info.raw_events,
l2_to_l1_messages=execution_info.l2_to_l1_messages
)


class InvokeTransactionWrapper(TransactionWrapper):
"""Wrapper of Invoke Transaction."""

def __init__(self, internal_tx: InternalInvokeFunction, status: TxStatus, execution_info: StarknetTransactionExecutionInfo):
def __init__(self, internal_tx: InternalInvokeFunction, status: TxStatus, execution_info: TransactionExecutionInfo):
call_info = execution_info.call_info
if status is not TxStatus.REJECTED:
call_info = FunctionInvocation.from_internal_version(call_info)
super().__init__(
status,
execution_info,
call_info,
InvokeTransactionDetails(
TransactionType.INVOKE_FUNCTION.name,
contract_address=fixed_length_hex(internal_tx.contract_address),
Expand All @@ -157,5 +167,7 @@ def __init__(self, internal_tx: InternalInvokeFunction, status: TxStatus, execut
entry_point_selector=fixed_length_hex(internal_tx.entry_point_selector),
entry_point_type=internal_tx.entry_point_type.name,
signature=[hex(sig_part) for sig_part in internal_tx.signature]
)
),
events=execution_info.get_sorted_events(),
l2_to_l1_messages=execution_info.get_sorted_l2_to_l1_messages()
)
9 changes: 9 additions & 0 deletions starknet_devnet/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,15 @@ def __init__(self):
self.retdata = []
self.internal_calls = []
self.l2_to_l1_messages = []
self.raw_events = []

def get_sorted_events(self):
"""Return empty list"""
return self.raw_events

def get_sorted_l2_to_l1_messages(self):
"""Return empty list"""
return self.l2_to_l1_messages

def enable_pickling():
"""
Expand Down
Loading