diff --git a/starknet_devnet/server.py b/starknet_devnet/server.py index 3742bfca7..4b528d26c 100644 --- a/starknet_devnet/server.py +++ b/starknet_devnet/server.py @@ -7,14 +7,13 @@ from flask import Flask, request, jsonify, abort from flask.wrappers import Response from flask_cors import CORS -from starkware.starknet.business_logic.internal_transaction import InternalDeploy from starkware.starknet.services.api.gateway.transaction import InvokeFunction, Transaction from starkware.starknet.definitions.transaction_type import TransactionType from starkware.starkware_utils.error_handling import StarkErrorCode, StarkException from werkzeug.datastructures import MultiDict -from .util import DummyExecutionInfo, TxStatus, custom_int, fixed_length_hex, parse_args -from .starknet_wrapper import Choice, StarknetWrapper +from .util import custom_int, fixed_length_hex, parse_args +from .starknet_wrapper import StarknetWrapper from .origin import NullOrigin app = Flask(__name__) @@ -41,42 +40,20 @@ async def add_transaction(): abort(Response(msg, 400)) tx_type = transaction.tx_type.name - status = TxStatus.ACCEPTED_ON_L2 - error_message = None if tx_type == TransactionType.DEPLOY.name: - state = await starknet_wrapper.get_state() - deploy_transaction: InternalDeploy = InternalDeploy.from_external(transaction, state.general_config) - contract_address = deploy_transaction.contract_address - transaction_hash = await starknet_wrapper.deploy(deploy_transaction) - + contract_address, transaction_hash = await starknet_wrapper.deploy(transaction) + result_dict = {} elif tx_type == TransactionType.INVOKE_FUNCTION.name: - transaction: InvokeFunction = transaction - contract_address = transaction.contract_address - execution_info = DummyExecutionInfo() - try: - _, execution_info = await starknet_wrapper.call_or_invoke( - Choice.INVOKE, - transaction - ) - except StarkException as err: - error_message = err.message - status = TxStatus.REJECTED - - transaction_hash = await starknet_wrapper.store_wrapper_transaction( - transaction=transaction, - status=status, - execution_info=execution_info, - error_message=error_message - ) - + contract_address, transaction_hash, result_dict = await starknet_wrapper.invoke(transaction) else: abort(Response(f"Invalid tx_type: {tx_type}.", 400)) return jsonify({ "code": StarkErrorCode.TRANSACTION_RECEIVED.name, - "transaction_hash": transaction_hash, - "address": fixed_length_hex(contract_address) + "transaction_hash": fixed_length_hex(transaction_hash), + "address": fixed_length_hex(contract_address), + **result_dict }) @app.route("/feeder_gateway/get_contract_addresses", methods=["GET"]) @@ -93,10 +70,7 @@ async def call_contract(): raw_data = request.get_data() try: call_specifications = InvokeFunction.loads(raw_data) - result_dict, _ = await starknet_wrapper.call_or_invoke( - Choice.CALL, - call_specifications - ) + result_dict = await starknet_wrapper.call(call_specifications) except StarkException as err: # code 400 would make more sense, but alpha returns 500 abort(Response(err.message, 500)) diff --git a/starknet_devnet/starknet_wrapper.py b/starknet_devnet/starknet_wrapper.py index b18a6002b..68438fa12 100644 --- a/starknet_devnet/starknet_wrapper.py +++ b/starknet_devnet/starknet_wrapper.py @@ -7,7 +7,7 @@ from copy import deepcopy from typing import Dict -from starkware.starknet.business_logic.internal_transaction import InternalDeploy +from starkware.starknet.business_logic.internal_transaction import InternalDeploy, InternalInvokeFunction, InternalTransaction from starkware.starknet.business_logic.state import CarriedState from starkware.starknet.definitions.transaction_type import TransactionType from starkware.starknet.services.api.gateway.transaction import InvokeFunction, Transaction @@ -58,7 +58,7 @@ async def get_starknet(self): await self.__preserve_current_state(self.__starknet.state.state) return self.__starknet - async def get_state(self): + async def __get_state(self): """ Returns the StarknetState of the underlyling Starknet instance, creating the instance first if necessary. @@ -69,7 +69,7 @@ async def get_state(self): async def __update_state(self): previous_state = self.__current_carried_state assert previous_state is not None - current_carried_state = (await self.get_state()).state + current_carried_state = (await self.__get_state()).state updated_shared_state = await current_carried_state.shared_state.apply_state_updates( ffc=current_carried_state.ffc, previous_carried_state=previous_state, @@ -77,10 +77,9 @@ async def __update_state(self): ) self.__starknet.state.state.shared_state = updated_shared_state await self.__preserve_current_state(self.__starknet.state.state) - # await self.preserve_carried_state(current_carried_state) async def __get_state_root(self): - state = await self.get_state() + state = await self.__get_state() return state.state.shared_state.contract_states.root.hex() def __is_contract_deployed(self, address: int) -> bool: @@ -93,63 +92,92 @@ def __get_contract_wrapper(self, address: int) -> ContractWrapper: return self.__address2contract_wrapper[address] - async def deploy(self, transaction: InternalDeploy): + async def deploy(self, transaction: Transaction): """ - Deploys the contract specified with `transaction` and returns tx hash in hex. + Deploys the contract specified with `transaction`. + Returns (contract_address, transaction_hash). """ + state = await self.__get_state() + deploy_transaction: InternalDeploy = InternalDeploy.from_external(transaction, state.general_config) + starknet = await self.get_starknet() - status = TxStatus.ACCEPTED_ON_L2 - error_message = None try: contract = await starknet.deploy( - contract_def=transaction.contract_definition, - constructor_calldata=transaction.constructor_calldata, - contract_address_salt=transaction.contract_address_salt + contract_def=deploy_transaction.contract_definition, + constructor_calldata=deploy_transaction.constructor_calldata, + contract_address_salt=deploy_transaction.contract_address_salt ) + # Uncomment this once contract has execution_info + # execution_info = contract.execution_info + execution_info = DummyExecutionInfo() + status = TxStatus.ACCEPTED_ON_L2 + error_message = None + await self.__update_state() except StarkException as err: error_message = err.message status = TxStatus.REJECTED + execution_info = DummyExecutionInfo() - transaction_hash = await self.store_wrapper_transaction( - transaction, + await self.__store_transaction( + internal_tx=deploy_transaction, status=status, - execution_info=DummyExecutionInfo(), + execution_info=execution_info, error_message=error_message ) - await self.__update_state() - self.__address2contract_wrapper[contract.contract_address] = ContractWrapper(contract, transaction.contract_definition) - return transaction_hash + self.__address2contract_wrapper[contract.contract_address] = ContractWrapper(contract, deploy_transaction.contract_definition) + return deploy_transaction.contract_address, deploy_transaction.hash_value - async def call_or_invoke(self, choice: Choice, specifications: InvokeFunction): - """ - Performs `ContractWrapper.call_or_invoke` on the contract at `contract_address`. - If `choice` is INVOKE, updates the state. - Returns a tuple of: - - `dict` with `"result"`, holding the adapted result - - `execution_info` - """ - contract_wrapper = self.__get_contract_wrapper(specifications.contract_address) - adapted_result, execution_info = await contract_wrapper.call_or_invoke( - choice, - entry_point_selector=specifications.entry_point_selector, - calldata=specifications.calldata, - signature=specifications.signature - ) + async def invoke(self, transaction: InvokeFunction): + """Perform invoke according to specifications in `transaction`.""" + state = await self.__get_state() + invoke_transaction: InternalInvokeFunction = InternalInvokeFunction.from_external(transaction, state.general_config) - if choice == Choice.INVOKE: + try: + contract_wrapper = self.__get_contract_wrapper(invoke_transaction.contract_address) + adapted_result, execution_info = await contract_wrapper.call_or_invoke( + Choice.INVOKE, + entry_point_selector=invoke_transaction.entry_point_selector, + calldata=invoke_transaction.calldata, + signature=invoke_transaction.signature + ) + status = TxStatus.ACCEPTED_ON_L2 + error_message = None await self.__update_state() + except StarkException as err: + error_message = err.message + status = TxStatus.REJECTED + execution_info = DummyExecutionInfo() + adapted_result = {} + + await self.__store_transaction( + internal_tx=invoke_transaction, + status=status, + execution_info=execution_info, + error_message=error_message + ) + + return transaction.contract_address, invoke_transaction.hash_value, { "result": adapted_result } + + async def call(self, transaction: InvokeFunction): + """Perform call according to specifications in `transaction`.""" + contract_wrapper = self.__get_contract_wrapper(transaction.contract_address) + adapted_result, _ = await contract_wrapper.call_or_invoke( + Choice.CALL, + entry_point_selector=transaction.entry_point_selector, + calldata=transaction.calldata, + signature=transaction.signature + ) - return { "result": adapted_result }, execution_info + return { "result": adapted_result } def get_transaction_status(self, transaction_hash: str): """Returns the status of the transaction identified by `transaction_hash`.""" - tx_hash_int = int(transaction_hash,16) + tx_hash_int = int(transaction_hash, 16) if tx_hash_int in self.__transaction_wrappers: - transaction_wrapper = self.__transaction_wrappers[tx_hash_int] transaction = transaction_wrapper.transaction @@ -174,7 +202,6 @@ def get_transaction(self, transaction_hash: str): tx_hash_int = int(transaction_hash,16) if tx_hash_int in self.__transaction_wrappers: - return self.__transaction_wrappers[tx_hash_int].transaction return self.origin.get_transaction(transaction_hash) @@ -184,7 +211,6 @@ def get_transaction_receipt(self, transaction_hash: str): tx_hash_int = int(transaction_hash,16) if tx_hash_int in self.__transaction_wrappers: - return self.__transaction_wrappers[tx_hash_int].receipt return { @@ -200,16 +226,15 @@ def get_number_of_blocks(self): async def __generate_block(self, transaction: dict, receipt: dict): """ Generates a block and stores it to blocks and hash2block. The block contains just the passed transaction. + Also modifies the `transaction` and `receipt` objects received. The `transaction` dict should also contain a key `transaction`. + Returns (block_hash, block_number). """ block_number = self.get_number_of_blocks() block_hash = hex(block_number) state_root = await self.__get_state_root() - transaction["block_hash"] = receipt["block_hash"] = block_hash - transaction["block_number"] = receipt["block_number"] = block_number - block = { "block_hash": block_hash, "block_number": block_number, @@ -225,6 +250,8 @@ async def __generate_block(self, transaction: dict, receipt: dict): self.__num2block[number_of_blocks] = block self.__hash2block[int(block_hash, 16)] = block + return block_hash, block_number + def __get_last_block(self): number_of_blocks = self.get_number_of_blocks() return self.get_block_by_number(number_of_blocks - 1) @@ -257,32 +284,26 @@ def get_block_by_number(self, block_number: int): return self.origin.get_block_by_number(block_number) - async def __store_transaction(self, transaction_wrapper: TransactionWrapper, error_message): - - if transaction_wrapper.transaction["status"] == TxStatus.REJECTED: - transaction_wrapper.set_transaction_failure(error_message) - else: - await self.__generate_block(transaction_wrapper.transaction, transaction_wrapper.receipt) - - self.__transaction_wrappers[int(transaction_wrapper.transaction_hash,16)] = transaction_wrapper - - return transaction_wrapper.transaction_hash - - async def store_wrapper_transaction(self, transaction: Transaction, status: TxStatus, + async def __store_transaction(self, internal_tx: InternalTransaction, status: TxStatus, execution_info: StarknetTransactionExecutionInfo, error_message: str=None ): """Stores the provided data as a deploy transaction in `self.transactions`.""" - - starknet = await self.get_starknet() - - if transaction.tx_type == TransactionType.DEPLOY: - tx_wrapper = DeployTransactionWrapper(transaction,status,starknet) + if internal_tx.tx_type == TransactionType.DEPLOY: + tx_wrapper = DeployTransactionWrapper(internal_tx, status, execution_info) + elif internal_tx.tx_type == TransactionType.INVOKE_FUNCTION: + tx_wrapper = InvokeTransactionWrapper(internal_tx, status, execution_info) else: - tx_wrapper = InvokeTransactionWrapper(transaction,status,starknet) + raise StarknetDevnetException(message=f"Illegal tx_type: {internal_tx.tx_type}") - tx_wrapper.generate_receipt(execution_info) + if status == TxStatus.REJECTED: + assert error_message, "error_message must be present if tx rejected" + tx_wrapper.set_failure_reason(error_message) + else: + block_hash, block_number = await self.__generate_block(tx_wrapper.transaction, tx_wrapper.receipt) + tx_wrapper.set_block_data(block_hash, block_number) - return await self.__store_transaction(tx_wrapper, error_message) + numeric_hash = int(tx_wrapper.transaction_hash, 16) + self.__transaction_wrappers[numeric_hash] = tx_wrapper def get_code(self, contract_address: int) -> dict: """Returns a `dict` with `abi` and `bytecode` of the contract at `contract_address`.""" @@ -296,7 +317,7 @@ async def get_storage_at(self, contract_address: int, key: int) -> str: Returns the storage identified by `key` from the contract at `contract_address`. """ - state = await self.get_state() + state = await self.__get_state() contract_states = state.state.contract_states state = contract_states[contract_address] diff --git a/starknet_devnet/transaction_wrapper.py b/starknet_devnet/transaction_wrapper.py index 269fb22c2..d31743f1a 100644 --- a/starknet_devnet/transaction_wrapper.py +++ b/starknet_devnet/transaction_wrapper.py @@ -3,88 +3,111 @@ """ from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List +from starkware.starknet.business_logic.internal_transaction import InternalDeploy, InternalInvokeFunction from starkware.starknet.definitions.error_codes import StarknetErrorCode from starkware.starknet.definitions.transaction_type import TransactionType +from starkware.starknet.testing.objects import StarknetTransactionExecutionInfo -from .util import fixed_length_hex +from .util import TxStatus, fixed_length_hex + +@dataclass +class TransactionDetails(ABC): + """Base class for `DeployTransactionDetails` and `InvokeTransactionDetails`.""" + type: str + contract_address: str + transaction_hash: str + + def to_dict(self): + """Get details in JSON/dict format.""" + return dict(self.__dict__) + +@dataclass +class DeployTransactionDetails(TransactionDetails): + """Transaction details of `DeployTransaction`.""" + constructor_calldata: List[str] + contract_address_salt: str + + +@dataclass +class InvokeTransactionDetails(TransactionDetails): + """Transcation details of `InvokeTransaction`.""" + calldata: List[str] + entry_point_selector: str class TransactionWrapper(ABC): """Transaction Wrapper base class.""" @abstractmethod - def __init__(self): - self.transaction = {} - self.receipt = {} - self.transaction_hash = None + def __init__( + self, status: TxStatus, execution_info: StarknetTransactionExecutionInfo, tx_details: TransactionDetails + ): + self.transaction_hash = tx_details.transaction_hash - def generate_transaction(self, internal_transaction, status, transaction_type, **transaction_details): - """Creates the transaction object""" self.transaction = { "status": status.name, - "transaction": { - "contract_address": fixed_length_hex(internal_transaction.contract_address), - "transaction_hash": self.transaction_hash, - "type": transaction_type.name, - **transaction_details - }, - "transaction_index": 0 # always the first (and only) tx in the block + "transaction": tx_details.to_dict(), + "transaction_index": 0 # always the first (and only) tx in the block } - def generate_receipt(self, execution_info): - """Creates the receipt for the transaction""" - self.receipt = { "execution_resources": execution_info.call_info.cairo_usage, "l2_to_l1_messages": execution_info.l2_to_l1_messages, - "status": self.transaction["status"], - "transaction_hash": self.transaction_hash, + "status": status.name, + "transaction_hash": tx_details.transaction_hash, "transaction_index": 0 # always the first (and only) tx in the block } - def set_transaction_failure(self, error_message: str): - """Creates a new entry `failure_key` in the transaction object with the transaction failure reason data.""" + def set_block_data(self, block_hash: str, block_number: int): + """Sets `block_hash` and `block_number` to the wrapped transaction and receipt.""" + self.transaction["block_hash"] = self.receipt["block_hash"] = block_hash + self.transaction["block_number"] = self.receipt["block_number"] = block_number + def set_failure_reason(self, error_message: str): + """Sets the failure reason to transaction and receipt dicts.""" + assert error_message + assert self.transaction + assert self.receipt failure_key = "transaction_failure_reason" self.transaction[failure_key] = self.receipt[failure_key] = { - "code": StarknetErrorCode.TRANSACTION_FAILED.name, - "error_message": error_message, - "tx_id": self.transaction_hash + "code": StarknetErrorCode.TRANSACTION_FAILED.name, + "error_message": error_message, + "tx_id": self.transaction_hash } class DeployTransactionWrapper(TransactionWrapper): - """Class for Deploy Transaction.""" - - def __init__(self, internal_deploy, status, starknet): - - super().__init__() - - self.transaction_hash = hex(internal_deploy.to_external().calculate_hash(starknet.state.general_config)) + """Wrapper of Deploy Transaction.""" - self.generate_transaction( - internal_deploy, + def __init__(self, internal_tx: InternalDeploy, status: TxStatus, execution_info: StarknetTransactionExecutionInfo): + super().__init__( status, - TransactionType.DEPLOY, - constructor_calldata=[str(arg) for arg in internal_deploy.constructor_calldata], - contract_address_salt=hex(internal_deploy.contract_address_salt) + execution_info, + DeployTransactionDetails( + TransactionType.DEPLOY.name, + contract_address=fixed_length_hex(internal_tx.contract_address), + transaction_hash=fixed_length_hex(internal_tx.hash_value), + constructor_calldata=[str(arg) for arg in internal_tx.constructor_calldata], + contract_address_salt=hex(internal_tx.contract_address_salt) + ) ) class InvokeTransactionWrapper(TransactionWrapper): - """Class for Invoke Transaction.""" - - def __init__(self, internal_transaction, status, starknet): - - super().__init__() - - self.transaction_hash = hex(internal_transaction.calculate_hash(starknet.state.general_config)) + """Wrapper of Invoke Transaction.""" - self.generate_transaction( - internal_transaction, + def __init__(self, internal_tx: InternalInvokeFunction, status: TxStatus, execution_info: StarknetTransactionExecutionInfo): + super().__init__( status, - TransactionType.INVOKE_FUNCTION, - calldata=[str(arg) for arg in internal_transaction.calldata], - entry_point_selector=str(internal_transaction.entry_point_selector) + execution_info, + InvokeTransactionDetails( + TransactionType.INVOKE_FUNCTION.name, + contract_address=fixed_length_hex(internal_tx.contract_address), + transaction_hash=fixed_length_hex(internal_tx.hash_value), + calldata=[str(arg) for arg in internal_tx.calldata], + entry_point_selector=str(internal_tx.entry_point_selector) + ) )