Skip to content
This repository has been archived by the owner on Dec 15, 2023. It is now read-only.

Refactor StarknetWrapper and TransactionWrapper #27

Merged
merged 6 commits into from
Jan 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 9 additions & 35 deletions starknet_devnet/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@
from flask import Flask, request, jsonify, abort
from flask.wrappers import Response
from flask_cors import CORS
from starkware.starknet.business_logic.internal_transaction import InternalDeploy
from starkware.starknet.services.api.gateway.transaction import InvokeFunction, Transaction
from starkware.starknet.definitions.transaction_type import TransactionType
from starkware.starkware_utils.error_handling import StarkErrorCode, StarkException
from werkzeug.datastructures import MultiDict

from .util import DummyExecutionInfo, TxStatus, custom_int, fixed_length_hex, parse_args
from .starknet_wrapper import Choice, StarknetWrapper
from .util import custom_int, fixed_length_hex, parse_args
from .starknet_wrapper import StarknetWrapper
from .origin import NullOrigin

app = Flask(__name__)
Expand All @@ -41,42 +40,20 @@ async def add_transaction():
abort(Response(msg, 400))

tx_type = transaction.tx_type.name
status = TxStatus.ACCEPTED_ON_L2
error_message = None

if tx_type == TransactionType.DEPLOY.name:
state = await starknet_wrapper.get_state()
deploy_transaction: InternalDeploy = InternalDeploy.from_external(transaction, state.general_config)
contract_address = deploy_transaction.contract_address
transaction_hash = await starknet_wrapper.deploy(deploy_transaction)

contract_address, transaction_hash = await starknet_wrapper.deploy(transaction)
result_dict = {}
elif tx_type == TransactionType.INVOKE_FUNCTION.name:
transaction: InvokeFunction = transaction
contract_address = transaction.contract_address
execution_info = DummyExecutionInfo()
try:
_, execution_info = await starknet_wrapper.call_or_invoke(
Choice.INVOKE,
transaction
)
except StarkException as err:
error_message = err.message
status = TxStatus.REJECTED

transaction_hash = await starknet_wrapper.store_wrapper_transaction(
transaction=transaction,
status=status,
execution_info=execution_info,
error_message=error_message
)

contract_address, transaction_hash, result_dict = await starknet_wrapper.invoke(transaction)
else:
abort(Response(f"Invalid tx_type: {tx_type}.", 400))

return jsonify({
"code": StarkErrorCode.TRANSACTION_RECEIVED.name,
"transaction_hash": transaction_hash,
"address": fixed_length_hex(contract_address)
"transaction_hash": fixed_length_hex(transaction_hash),
"address": fixed_length_hex(contract_address),
**result_dict
})

@app.route("/feeder_gateway/get_contract_addresses", methods=["GET"])
Expand All @@ -93,10 +70,7 @@ async def call_contract():
raw_data = request.get_data()
try:
call_specifications = InvokeFunction.loads(raw_data)
result_dict, _ = await starknet_wrapper.call_or_invoke(
Choice.CALL,
call_specifications
)
result_dict = await starknet_wrapper.call(call_specifications)
except StarkException as err:
# code 400 would make more sense, but alpha returns 500
abort(Response(err.message, 500))
Expand Down
147 changes: 84 additions & 63 deletions starknet_devnet/starknet_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from copy import deepcopy
from typing import Dict

from starkware.starknet.business_logic.internal_transaction import InternalDeploy
from starkware.starknet.business_logic.internal_transaction import InternalDeploy, InternalInvokeFunction, InternalTransaction
from starkware.starknet.business_logic.state import CarriedState
from starkware.starknet.definitions.transaction_type import TransactionType
from starkware.starknet.services.api.gateway.transaction import InvokeFunction, Transaction
Expand Down Expand Up @@ -58,7 +58,7 @@ async def get_starknet(self):
await self.__preserve_current_state(self.__starknet.state.state)
return self.__starknet

async def get_state(self):
async def __get_state(self):
"""
Returns the StarknetState of the underlyling Starknet instance,
creating the instance first if necessary.
Expand All @@ -69,18 +69,17 @@ async def get_state(self):
async def __update_state(self):
previous_state = self.__current_carried_state
assert previous_state is not None
current_carried_state = (await self.get_state()).state
current_carried_state = (await self.__get_state()).state
updated_shared_state = await current_carried_state.shared_state.apply_state_updates(
ffc=current_carried_state.ffc,
previous_carried_state=previous_state,
current_carried_state=current_carried_state
)
self.__starknet.state.state.shared_state = updated_shared_state
await self.__preserve_current_state(self.__starknet.state.state)
# await self.preserve_carried_state(current_carried_state)

async def __get_state_root(self):
state = await self.get_state()
state = await self.__get_state()
return state.state.shared_state.contract_states.root.hex()

def __is_contract_deployed(self, address: int) -> bool:
Expand All @@ -93,63 +92,92 @@ def __get_contract_wrapper(self, address: int) -> ContractWrapper:

return self.__address2contract_wrapper[address]

async def deploy(self, transaction: InternalDeploy):
async def deploy(self, transaction: Transaction):
"""
Deploys the contract specified with `transaction` and returns tx hash in hex.
Deploys the contract specified with `transaction`.
Returns (contract_address, transaction_hash).
"""

state = await self.__get_state()
deploy_transaction: InternalDeploy = InternalDeploy.from_external(transaction, state.general_config)

starknet = await self.get_starknet()
status = TxStatus.ACCEPTED_ON_L2
error_message = None

try:
contract = await starknet.deploy(
contract_def=transaction.contract_definition,
constructor_calldata=transaction.constructor_calldata,
contract_address_salt=transaction.contract_address_salt
contract_def=deploy_transaction.contract_definition,
constructor_calldata=deploy_transaction.constructor_calldata,
contract_address_salt=deploy_transaction.contract_address_salt
)
# Uncomment this once contract has execution_info
# execution_info = contract.execution_info
execution_info = DummyExecutionInfo()
status = TxStatus.ACCEPTED_ON_L2
error_message = None
await self.__update_state()
except StarkException as err:
error_message = err.message
status = TxStatus.REJECTED
execution_info = DummyExecutionInfo()

transaction_hash = await self.store_wrapper_transaction(
transaction,
await self.__store_transaction(
internal_tx=deploy_transaction,
status=status,
execution_info=DummyExecutionInfo(),
execution_info=execution_info,
error_message=error_message
)

await self.__update_state()
self.__address2contract_wrapper[contract.contract_address] = ContractWrapper(contract, transaction.contract_definition)
return transaction_hash
self.__address2contract_wrapper[contract.contract_address] = ContractWrapper(contract, deploy_transaction.contract_definition)
return deploy_transaction.contract_address, deploy_transaction.hash_value

async def call_or_invoke(self, choice: Choice, specifications: InvokeFunction):
"""
Performs `ContractWrapper.call_or_invoke` on the contract at `contract_address`.
If `choice` is INVOKE, updates the state.
Returns a tuple of:
- `dict` with `"result"`, holding the adapted result
- `execution_info`
"""
contract_wrapper = self.__get_contract_wrapper(specifications.contract_address)
adapted_result, execution_info = await contract_wrapper.call_or_invoke(
choice,
entry_point_selector=specifications.entry_point_selector,
calldata=specifications.calldata,
signature=specifications.signature
)
async def invoke(self, transaction: InvokeFunction):
"""Perform invoke according to specifications in `transaction`."""
state = await self.__get_state()
invoke_transaction: InternalInvokeFunction = InternalInvokeFunction.from_external(transaction, state.general_config)

if choice == Choice.INVOKE:
try:
contract_wrapper = self.__get_contract_wrapper(invoke_transaction.contract_address)
adapted_result, execution_info = await contract_wrapper.call_or_invoke(
Choice.INVOKE,
entry_point_selector=invoke_transaction.entry_point_selector,
calldata=invoke_transaction.calldata,
signature=invoke_transaction.signature
)
status = TxStatus.ACCEPTED_ON_L2
error_message = None
await self.__update_state()
except StarkException as err:
error_message = err.message
status = TxStatus.REJECTED
execution_info = DummyExecutionInfo()
adapted_result = {}

await self.__store_transaction(
internal_tx=invoke_transaction,
status=status,
execution_info=execution_info,
error_message=error_message
)

return transaction.contract_address, invoke_transaction.hash_value, { "result": adapted_result }

async def call(self, transaction: InvokeFunction):
"""Perform call according to specifications in `transaction`."""
contract_wrapper = self.__get_contract_wrapper(transaction.contract_address)
adapted_result, _ = await contract_wrapper.call_or_invoke(
Choice.CALL,
entry_point_selector=transaction.entry_point_selector,
calldata=transaction.calldata,
signature=transaction.signature
)

return { "result": adapted_result }, execution_info
return { "result": adapted_result }

def get_transaction_status(self, transaction_hash: str):
"""Returns the status of the transaction identified by `transaction_hash`."""

tx_hash_int = int(transaction_hash,16)
tx_hash_int = int(transaction_hash, 16)
if tx_hash_int in self.__transaction_wrappers:

transaction_wrapper = self.__transaction_wrappers[tx_hash_int]

transaction = transaction_wrapper.transaction
Expand All @@ -174,7 +202,6 @@ def get_transaction(self, transaction_hash: str):

tx_hash_int = int(transaction_hash,16)
if tx_hash_int in self.__transaction_wrappers:

return self.__transaction_wrappers[tx_hash_int].transaction

return self.origin.get_transaction(transaction_hash)
Expand All @@ -184,7 +211,6 @@ def get_transaction_receipt(self, transaction_hash: str):

tx_hash_int = int(transaction_hash,16)
if tx_hash_int in self.__transaction_wrappers:

return self.__transaction_wrappers[tx_hash_int].receipt

return {
Expand All @@ -200,16 +226,15 @@ def get_number_of_blocks(self):
async def __generate_block(self, transaction: dict, receipt: dict):
"""
Generates a block and stores it to blocks and hash2block. The block contains just the passed transaction.
Also modifies the `transaction` and `receipt` objects received.
The `transaction` dict should also contain a key `transaction`.
Returns (block_hash, block_number).
"""

block_number = self.get_number_of_blocks()
block_hash = hex(block_number)
state_root = await self.__get_state_root()

transaction["block_hash"] = receipt["block_hash"] = block_hash
transaction["block_number"] = receipt["block_number"] = block_number

block = {
"block_hash": block_hash,
"block_number": block_number,
Expand All @@ -225,6 +250,8 @@ async def __generate_block(self, transaction: dict, receipt: dict):
self.__num2block[number_of_blocks] = block
self.__hash2block[int(block_hash, 16)] = block

return block_hash, block_number

def __get_last_block(self):
number_of_blocks = self.get_number_of_blocks()
return self.get_block_by_number(number_of_blocks - 1)
Expand Down Expand Up @@ -257,32 +284,26 @@ def get_block_by_number(self, block_number: int):

return self.origin.get_block_by_number(block_number)

async def __store_transaction(self, transaction_wrapper: TransactionWrapper, error_message):

if transaction_wrapper.transaction["status"] == TxStatus.REJECTED:
transaction_wrapper.set_transaction_failure(error_message)
else:
await self.__generate_block(transaction_wrapper.transaction, transaction_wrapper.receipt)

self.__transaction_wrappers[int(transaction_wrapper.transaction_hash,16)] = transaction_wrapper

return transaction_wrapper.transaction_hash

async def store_wrapper_transaction(self, transaction: Transaction, status: TxStatus,
async def __store_transaction(self, internal_tx: InternalTransaction, status: TxStatus,
execution_info: StarknetTransactionExecutionInfo, error_message: str=None
):
"""Stores the provided data as a deploy transaction in `self.transactions`."""

starknet = await self.get_starknet()

if transaction.tx_type == TransactionType.DEPLOY:
tx_wrapper = DeployTransactionWrapper(transaction,status,starknet)
if internal_tx.tx_type == TransactionType.DEPLOY:
tx_wrapper = DeployTransactionWrapper(internal_tx, status, execution_info)
elif internal_tx.tx_type == TransactionType.INVOKE_FUNCTION:
tx_wrapper = InvokeTransactionWrapper(internal_tx, status, execution_info)
else:
tx_wrapper = InvokeTransactionWrapper(transaction,status,starknet)
raise StarknetDevnetException(message=f"Illegal tx_type: {internal_tx.tx_type}")

tx_wrapper.generate_receipt(execution_info)
if status == TxStatus.REJECTED:
assert error_message, "error_message must be present if tx rejected"
tx_wrapper.set_failure_reason(error_message)
else:
block_hash, block_number = await self.__generate_block(tx_wrapper.transaction, tx_wrapper.receipt)
tx_wrapper.set_block_data(block_hash, block_number)

return await self.__store_transaction(tx_wrapper, error_message)
numeric_hash = int(tx_wrapper.transaction_hash, 16)
self.__transaction_wrappers[numeric_hash] = tx_wrapper

def get_code(self, contract_address: int) -> dict:
"""Returns a `dict` with `abi` and `bytecode` of the contract at `contract_address`."""
Expand All @@ -296,7 +317,7 @@ async def get_storage_at(self, contract_address: int, key: int) -> str:
Returns the storage identified by `key`
from the contract at `contract_address`.
"""
state = await self.get_state()
state = await self.__get_state()
contract_states = state.state.contract_states

state = contract_states[contract_address]
Expand Down
Loading