Skip to content
This repository has been archived by the owner on Dec 15, 2023. It is now read-only.

Commit

Permalink
Call raw (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
badurinantun authored May 13, 2022
1 parent e4d0b29 commit 687b18e
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 64 deletions.
66 changes: 13 additions & 53 deletions starknet_devnet/contract_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,17 @@
from dataclasses import dataclass
from typing import List

from starkware.starknet.services.api.contract_definition import ContractDefinition, EntryPointType
from starkware.starknet.services.api.contract_definition import ContractDefinition
from starkware.starknet.testing.contract import StarknetContract
from starkware.starknet.utils.api_utils import cast_to_felts
from starkware.starknet.business_logic.internal_transaction import InternalInvokeFunction
from starkware.starknet.definitions import constants
from starkware.starknet.public.abi import get_selector_from_name
from starkware.starknet.business_logic.execution.execute_entry_point import ExecuteEntryPoint
from starkware.starknet.business_logic.execution.objects import (
TransactionExecutionContext,
TransactionExecutionInfo,
)
from starkware.starknet.testing.state import StarknetState

from starknet_devnet.util import Choice

async def call_internal_tx(starknet_state: StarknetState, internal_tx: InternalInvokeFunction):
"""
Executes an internal transaction.
Expand Down Expand Up @@ -67,29 +63,7 @@ def __init__(self, contract: StarknetContract, contract_definition: ContractDefi
"bytecode": self.contract_definition["program"]["data"]
}


# pylint: disable=too-many-arguments
async def call_or_invoke(
self,
choice: Choice,
entry_point_selector: int,
calldata: List[int],
signature: List[int],
caller_address: int,
max_fee: int
):
"""
Depending on `choice`, performs the call or invoke of the function
identified with `entry_point_selector`, potentially passing in `calldata` and `signature`.
"""
if choice == Choice.CALL:
execution_info = await self.call(entry_point_selector, calldata, signature, caller_address, max_fee)
else:
execution_info = await self.invoke(entry_point_selector, calldata, signature, caller_address, max_fee)

result = list(map(hex, execution_info.call_info.retdata))
return result, execution_info

async def call(
self,
entry_point_selector: int,
Expand All @@ -101,36 +75,19 @@ async def call(
"""
Calls the function identified with `entry_point_selector`, potentially passing in `calldata` and `signature`.
"""
starknet_state = self.contract.state.copy()
contract_address = self.contract.contract_address
selector = entry_point_selector

if isinstance(contract_address, str):
contract_address = int(contract_address, 16)
assert isinstance(contract_address, int)

if isinstance(selector, str):
selector = get_selector_from_name(selector)
assert isinstance(selector, int)

if signature is None:
signature = []

internal_tx = InternalInvokeFunction.create(
contract_address=contract_address,
entry_point_selector=selector,
entry_point_type=EntryPointType.EXTERNAL,
call_info = await self.contract.state.call_raw(
calldata=calldata,
max_fee=max_fee,
signature=signature,
caller_address=caller_address,
nonce=None,
chain_id=starknet_state.general_config.chain_id.value,
version=constants.QUERY_VERSION,
only_query=True,
contract_address=self.contract.contract_address,
max_fee=max_fee,
selector=entry_point_selector,
signature=signature and cast_to_felts(values=signature)
)

return await call_internal_tx(starknet_state, internal_tx)
result = list(map(hex, call_info.retdata))

return result

async def invoke(
self,
Expand All @@ -144,11 +101,14 @@ async def invoke(
Invokes the function identified with `entry_point_selector`, potentially passing in `calldata` and `signature`.
"""

return await self.contract.state.invoke_raw(
execution_info = await self.contract.state.invoke_raw(
contract_address=self.contract.contract_address,
selector=entry_point_selector,
calldata=calldata,
caller_address=caller_address,
max_fee=max_fee,
signature=signature and cast_to_felts(values=signature),
)

result = list(map(hex, execution_info.call_info.retdata))
return result, execution_info
8 changes: 3 additions & 5 deletions starknet_devnet/starknet_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .origin import NullOrigin, Origin
from .general_config import DEFAULT_GENERAL_CONFIG
from .util import (
Choice, StarknetDevnetException, TxStatus, DummyExecutionInfo,
StarknetDevnetException, TxStatus, DummyExecutionInfo,
fixed_length_hex, enable_pickling, generate_state_update
)
from .contract_wrapper import ContractWrapper, call_internal_tx
Expand Down Expand Up @@ -206,8 +206,7 @@ async def invoke(self, transaction: InvokeFunction):
raise StarknetDevnetException(message=message)

contract_wrapper = self.__get_contract_wrapper(invoke_transaction.contract_address)
adapted_result, execution_info = await contract_wrapper.call_or_invoke(
Choice.INVOKE,
adapted_result, execution_info = await contract_wrapper.invoke(
entry_point_selector=invoke_transaction.entry_point_selector,
calldata=invoke_transaction.calldata,
signature=invoke_transaction.signature,
Expand Down Expand Up @@ -238,8 +237,7 @@ async def call(self, transaction: InvokeFunction):
"""Perform call according to specifications in `transaction`."""
contract_wrapper = self.__get_contract_wrapper(transaction.contract_address)

adapted_result, _ = await contract_wrapper.call_or_invoke(
Choice.CALL,
adapted_result = await contract_wrapper.call(
entry_point_selector=transaction.entry_point_selector,
calldata=transaction.calldata,
signature=transaction.signature,
Expand Down
6 changes: 0 additions & 6 deletions starknet_devnet/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,6 @@ class TxStatus(Enum):
"""The transaction was accepted on-chain."""


class Choice(Enum):
"""Enumerates ways of interacting with a Starknet function."""
CALL = "call"
INVOKE = "invoke"


def custom_int(arg: str) -> str:
"""
Converts the argument to an integer.
Expand Down

0 comments on commit 687b18e

Please sign in to comment.