diff --git a/eth/chains/base.py b/eth/chains/base.py index 76d54a1999..0264974a8e 100644 --- a/eth/chains/base.py +++ b/eth/chains/base.py @@ -892,3 +892,15 @@ async def coro_validate_receipt(self, receipt: Receipt, at_header: BlockHeader) -> None: raise NotImplementedError() + + async def coro_get_block_by_hash(self, + block_hash: Hash32) -> BaseBlock: + raise NotImplementedError() + + async def coro_get_block_by_header(self, + header: BlockHeader) -> BaseBlock: + raise NotImplementedError() + + async def coro_get_canonical_block_by_number(self, + block_number: BlockNumber) -> BaseBlock: + raise NotImplementedError() diff --git a/eth/tools/fixtures/helpers.py b/eth/tools/fixtures/helpers.py index 4943b88a08..b7fb14f488 100644 --- a/eth/tools/fixtures/helpers.py +++ b/eth/tools/fixtures/helpers.py @@ -146,12 +146,12 @@ def genesis_params_from_fixture(fixture): } -def new_chain_from_fixture(fixture): +def new_chain_from_fixture(fixture, chain_cls=MainnetChain): base_db = MemoryDB() vm_config = chain_vm_configuration(fixture) - ChainFromFixture = MainnetChain.configure( + ChainFromFixture = chain_cls.configure( 'ChainFromFixture', vm_configuration=vm_config, ) diff --git a/fixtures b/fixtures index f4faae91c5..47b09f42c0 160000 --- a/fixtures +++ b/fixtures @@ -1 +1 @@ -Subproject commit f4faae91c5ba192c3fd9b8cf418c24e627786312 +Subproject commit 47b09f42c0681548a00da5ab1c98808b368af49a diff --git a/p2p/events.py b/p2p/events.py new file mode 100644 index 0000000000..cc0ff32478 --- /dev/null +++ b/p2p/events.py @@ -0,0 +1,21 @@ +from typing import ( + Type, +) + +from lahja import ( + BaseEvent, + BaseRequestResponseEvent, +) + + +class PeerCountResponse(BaseEvent): + + def __init__(self, peer_count: int) -> None: + self.peer_count = peer_count + + +class PeerCountRequest(BaseRequestResponseEvent[PeerCountResponse]): + + @staticmethod + def expected_response_type() -> Type[PeerCountResponse]: + return PeerCountResponse diff --git a/p2p/peer.py b/p2p/peer.py index 
ae514a007f..f89f7751f4 100644 --- a/p2p/peer.py +++ b/p2p/peer.py @@ -52,6 +52,10 @@ from cancel_token import CancelToken, OperationCancelled +from lahja import ( + Endpoint, +) + from eth.chains.mainnet import MAINNET_NETWORK_ID from eth.chains.ropsten import ROPSTEN_NETWORK_ID from eth.constants import GENESIS_BLOCK_NUMBER @@ -101,6 +105,11 @@ MAC_LEN, ) +from .events import ( + PeerCountRequest, + PeerCountResponse, +) + if TYPE_CHECKING: from trinity.db.header import BaseAsyncHeaderDB # noqa: F401 from trinity.protocol.common.proto import ChainInfo # noqa: F401 @@ -784,6 +793,7 @@ def __init__(self, vm_configuration: Tuple[Tuple[int, Type[BaseVM]], ...], max_peers: int = DEFAULT_MAX_PEERS, token: CancelToken = None, + event_bus: Endpoint = None ) -> None: super().__init__(token) self.peer_class = peer_class @@ -794,6 +804,16 @@ def __init__(self, self.max_peers = max_peers self.connected_nodes: Dict[Node, BasePeer] = {} self._subscribers: List[PeerSubscriber] = [] + self.event_bus = event_bus + self.run_task(self.handle_peer_count_requests()) + + async def handle_peer_count_requests(self) -> None: + async for req in self.event_bus.stream(PeerCountRequest): + # We are listening for all `PeerCountRequest` events but we ensure to + # only send a `PeerCountResponse` to the callsite that made the request. + # We do that by retrieving a `BroadcastConfig` from the request via the + # `event.broadcast_config()` API. 
+ self.event_bus.broadcast(PeerCountResponse(len(self)), req.broadcast_config()) def __len__(self) -> int: return len(self.connected_nodes) diff --git a/setup.py b/setup.py index c17aa8fa2c..214a6eda22 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ "ipython>=6.2.1,<7.0.0", "plyvel==1.0.5", "web3==4.4.1", - "lahja==0.6.1", + "lahja==0.8.0", ], 'test': [ "hypothesis==3.69.5", diff --git a/tests/conftest.py b/tests/conftest.py index 998e4b179d..3d2ee36917 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -78,8 +78,7 @@ def funded_address_initial_balance(): return to_wei(1000, 'ether') -@pytest.fixture -def chain_with_block_validation(base_db, genesis_state): +def _chain_with_block_validation(base_db, genesis_state, chain_cls=Chain): """ Return a Chain object containing just the genesis block. @@ -107,7 +106,8 @@ def chain_with_block_validation(base_db, genesis_state): "transaction_root": decode_hex("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421"), # noqa: E501 "uncles_hash": decode_hex("1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347") # noqa: E501 } - klass = Chain.configure( + + klass = chain_cls.configure( __name__='TestChain', vm_configuration=( (constants.GENESIS_BLOCK_NUMBER, SpuriousDragonVM), @@ -118,6 +118,11 @@ def chain_with_block_validation(base_db, genesis_state): return chain +@pytest.fixture +def chain_with_block_validation(base_db, genesis_state): + return _chain_with_block_validation(base_db, genesis_state) + + def import_block_without_validation(chain, block): return super(type(chain), chain).import_block(block, perform_validation=False) diff --git a/tests/trinity/conftest.py b/tests/trinity/conftest.py index a19a0267c5..637d72596b 100644 --- a/tests/trinity/conftest.py +++ b/tests/trinity/conftest.py @@ -5,8 +5,19 @@ import tempfile import uuid +from lahja import ( + EventBus, +) + +from eth.chains import ( + Chain, +) + from p2p.peer import PeerPool +from trinity.chains.coro import ( + 
AsyncChainMixin, +) from trinity.rpc.main import ( RPCServer, ) @@ -22,6 +33,13 @@ from trinity.utils.filesystem import ( is_under_path, ) +from tests.conftest import ( + _chain_with_block_validation, +) + + +class TestAsyncChain(Chain, AsyncChainMixin): + pass def pytest_addoption(parser): @@ -51,6 +69,19 @@ def event_loop(): loop.close() +@pytest.fixture(scope='module') +def event_bus(event_loop): + bus = EventBus() + endpoint = bus.create_endpoint('test') + bus.start(event_loop) + endpoint.connect(event_loop) + try: + yield endpoint + finally: + endpoint.stop() + bus.stop() + + @pytest.fixture(scope='session') def jsonrpc_ipc_pipe_path(): with tempfile.TemporaryDirectory() as temp_dir: @@ -64,11 +95,16 @@ def p2p_server(monkeypatch, jsonrpc_ipc_pipe_path): return Server(None, None, None, None, None, None, None) +@pytest.fixture +def chain_with_block_validation(base_db, genesis_state): + return _chain_with_block_validation(base_db, genesis_state, TestAsyncChain) + + @pytest.mark.asyncio @pytest.fixture async def ipc_server( monkeypatch, - p2p_server, + event_bus, jsonrpc_ipc_pipe_path, event_loop, chain_with_block_validation): @@ -77,7 +113,7 @@ async def ipc_server( the course of all tests. 
It yields the IPC server only for monkeypatching purposes ''' - rpc = RPCServer(chain_with_block_validation, p2p_server.peer_pool) + rpc = RPCServer(chain_with_block_validation, event_bus) ipc_server = IPCServer(rpc, jsonrpc_ipc_pipe_path, loop=event_loop) asyncio.ensure_future(ipc_server.run(), loop=event_loop) diff --git a/tests/trinity/core/chain-management/test_light_peer_chain.py b/tests/trinity/core/chain-management/test_light_peer_chain.py new file mode 100644 index 0000000000..9cabf18a25 --- /dev/null +++ b/tests/trinity/core/chain-management/test_light_peer_chain.py @@ -0,0 +1,20 @@ +from trinity.sync.light.service import ( + LightPeerChain +) +from trinity.plugins.builtin.light_peer_chain_bridge import ( + EventBusLightPeerChain, +) + + +# These tests may seem obvious but they save us from runtime errors where +# changes are made to the `BaseLightPeerChain` that are then forgotten to +# implement on both derived chains. + +def test_can_instantiate_eventbus_light_peer_chain(): + chain = EventBusLightPeerChain(None) + assert chain is not None + + +def test_can_instantiate_light_peer_chain(): + chain = LightPeerChain(None, None) + assert chain is not None diff --git a/tests/trinity/core/json-rpc/test_ipc.py b/tests/trinity/core/json-rpc/test_ipc.py index f2c68134cb..019b940fe9 100644 --- a/tests/trinity/core/json-rpc/test_ipc.py +++ b/tests/trinity/core/json-rpc/test_ipc.py @@ -14,6 +14,11 @@ to_hex, ) +from p2p.events import ( + PeerCountRequest, + PeerCountResponse, +) + from trinity.utils.version import construct_trinity_client_identifier @@ -35,15 +40,6 @@ def build_request(method, params=[]): return json.dumps(request).encode() -class MockPeerPool: - - def __init__(self, peer_count=0): - self.peer_count = peer_count - - def __len__(self): - return self.peer_count - - def id_from_rpc_request(param): if isinstance(param, bytes): request = json.loads(param.decode()) @@ -68,6 +64,7 @@ async def get_ipc_response( jsonrpc_ipc_pipe_path, request_msg, 
event_loop): + assert wait_for(jsonrpc_ipc_pipe_path), "IPC server did not successfully start with IPC file" reader, writer = await asyncio.open_unix_connection(str(jsonrpc_ipc_pipe_path), loop=event_loop) @@ -399,18 +396,27 @@ async def test_eth_call_with_contract_on_ipc( assert result == expected +def mock_peer_count(count): + async def mock_event_bus_interaction(bus): + async for req in bus.stream(PeerCountRequest): + bus.broadcast(PeerCountResponse(count), req.broadcast_config()) + break + + return mock_event_bus_interaction + + @pytest.mark.asyncio @pytest.mark.parametrize( - 'request_msg, mock_peer_pool, expected', + 'request_msg, event_bus_setup_fn, expected', ( ( build_request('net_peerCount'), - MockPeerPool(peer_count=1), + mock_peer_count(1), {'result': '0x1', 'id': 3, 'jsonrpc': '2.0'}, ), ( build_request('net_peerCount'), - MockPeerPool(peer_count=0), + mock_peer_count(0), {'result': '0x0', 'id': 3, 'jsonrpc': '2.0'}, ), ), @@ -422,10 +428,17 @@ async def test_peer_pool_over_ipc( monkeypatch, jsonrpc_ipc_pipe_path, request_msg, - mock_peer_pool, + event_bus_setup_fn, + event_bus, expected, event_loop, ipc_server): - monkeypatch.setattr(ipc_server.rpc.modules['net'], '_peer_pool', mock_peer_pool) - result = await get_ipc_response(jsonrpc_ipc_pipe_path, request_msg, event_loop) + + asyncio.ensure_future(event_bus_setup_fn(event_bus)) + + result = await get_ipc_response( + jsonrpc_ipc_pipe_path, + request_msg, + event_loop + ) assert result == expected diff --git a/tests/trinity/integration/test_lightchain_integration.py b/tests/trinity/integration/test_lightchain_integration.py index 65967282e7..de1f3a6174 100644 --- a/tests/trinity/integration/test_lightchain_integration.py +++ b/tests/trinity/integration/test_lightchain_integration.py @@ -198,26 +198,26 @@ async def wait_for_header_sync(block_number): # https://ropsten.etherscan.io/block/11 header = headerdb.get_canonical_block_header_by_number(n) - body = await 
peer_chain.get_block_body_by_hash(header.hash) + body = await peer_chain.coro_get_block_body_by_hash(header.hash) assert len(body['transactions']) == 15 - receipts = await peer_chain.get_receipts(header.hash) + receipts = await peer_chain.coro_get_receipts(header.hash) assert len(receipts) == 15 assert encode_hex(keccak(rlp.encode(receipts[0]))) == ( '0xf709ed2c57efc18a1675e8c740f3294c9e2cb36ba7bb3b89d3ab4c8fef9d8860') assert len(peer_pool) == 1 peer = peer_pool.highest_td_peer - head = await peer_chain.get_block_header_by_hash(peer.head_hash) + head = await peer_chain.coro_get_block_header_by_hash(peer.head_hash) # In order to answer queries for contract code, geth needs the state trie entry for the block # we specify in the query, but because of fast sync we can only assume it has that for recent # blocks, so we use the current head to lookup the code for the contract below. # https://ropsten.etherscan.io/address/0x95a48dca999c89e4e284930d9b9af973a7481287 contract_addr = decode_hex('0x8B09D9ac6A4F7778fCb22852e879C7F3B2bEeF81') - contract_code = await peer_chain.get_contract_code(head.hash, contract_addr) + contract_code = await peer_chain.coro_get_contract_code(head.hash, contract_addr) assert encode_hex(contract_code) == '0x600060006000600060006000356000f1' - account = await peer_chain.get_account(head.hash, contract_addr) + account = await peer_chain.coro_get_account(head.hash, contract_addr) assert account.code_hash == keccak(contract_code) assert account.balance == 0 diff --git a/tests/trinity/json-fixtures-over-rpc/test_rpc_fixtures.py b/tests/trinity/json-fixtures-over-rpc/test_rpc_fixtures.py index 297a82b270..3700d1d0c9 100644 --- a/tests/trinity/json-fixtures-over-rpc/test_rpc_fixtures.py +++ b/tests/trinity/json-fixtures-over-rpc/test_rpc_fixtures.py @@ -25,6 +25,9 @@ should_run_slow_tests, ) +from trinity.chains.mainnet import ( + MainnetFullChain +) from trinity.rpc import RPCServer from trinity.rpc.format import ( empty_to_0x, @@ -199,21 +202,21 @@ 
def result_from_response(response_str): return (response.get('result', None), response.get('error', None)) -def call_rpc(rpc, method, params): +async def call_rpc(rpc, method, params): request = build_request(method, params) - response = rpc.execute(request) + response = await rpc.execute(request) return result_from_response(response) -def assert_rpc_result(rpc, method, params, expected): - result, error = call_rpc(rpc, method, params) +async def assert_rpc_result(rpc, method, params, expected): + result, error = await call_rpc(rpc, method, params) assert error is None assert result == expected return result -def validate_account_attribute(fixture_key, rpc_method, rpc, state, addr, at_block): - state_result, state_error = call_rpc(rpc, rpc_method, [addr, at_block]) +async def validate_account_attribute(fixture_key, rpc_method, rpc, state, addr, at_block): + state_result, state_error = await call_rpc(rpc, rpc_method, [addr, at_block]) assert state_result == state[fixture_key], "Invalid state - %s" % state_error @@ -224,19 +227,31 @@ def validate_account_attribute(fixture_key, rpc_method, rpc, state, addr, at_blo ) -def validate_account_state(rpc, state, addr, at_block): +async def validate_account_state(rpc, state, addr, at_block): standardized_state = fixture_state_in_rpc_format(state) for fixture_key, rpc_method in RPC_STATE_LOOKUPS: - validate_account_attribute(fixture_key, rpc_method, rpc, standardized_state, addr, at_block) + await validate_account_attribute( + fixture_key, + rpc_method, + rpc, + standardized_state, + addr, + at_block + ) for key in state['storage']: position = '0x0' if key == '0x' else key expected_storage = state['storage'][key] - assert_rpc_result(rpc, 'eth_getStorageAt', [addr, position, at_block], expected_storage) + await assert_rpc_result( + rpc, + 'eth_getStorageAt', + [addr, position, at_block], + expected_storage + ) -def validate_accounts(rpc, states, at_block='latest'): +async def validate_accounts(rpc, states, at_block='latest'): 
for addr in states: - validate_account_state(rpc, states[addr], addr, at_block) + await validate_account_state(rpc, states[addr], addr, at_block) def validate_rpc_block_vs_fixture(block, block_fixture): @@ -264,13 +279,13 @@ def is_by_hash(at_block): raise ValueError("Unrecognized 'at_block' value: %r" % at_block) -def validate_transaction_count(rpc, block_fixture, at_block): +async def validate_transaction_count(rpc, block_fixture, at_block): if is_by_hash(at_block): rpc_method = 'eth_getBlockTransactionCountByHash' else: rpc_method = 'eth_getBlockTransactionCountByNumber' expected_transaction_count = hex(len(block_fixture['transactions'])) - assert_rpc_result(rpc, rpc_method, [at_block], expected_transaction_count) + await assert_rpc_result(rpc, rpc_method, [at_block], expected_transaction_count) def validate_rpc_transaction_vs_fixture(transaction, fixture): @@ -282,74 +297,74 @@ def validate_rpc_transaction_vs_fixture(transaction, fixture): assert actual_transaction == expected -def validate_transaction_by_index(rpc, transaction_fixture, at_block, index): +async def validate_transaction_by_index(rpc, transaction_fixture, at_block, index): if is_by_hash(at_block): rpc_method = 'eth_getTransactionByBlockHashAndIndex' else: rpc_method = 'eth_getTransactionByBlockNumberAndIndex' - result, error = call_rpc(rpc, rpc_method, [at_block, hex(index)]) + result, error = await call_rpc(rpc, rpc_method, [at_block, hex(index)]) assert error is None validate_rpc_transaction_vs_fixture(result, transaction_fixture) -def validate_block(rpc, block_fixture, at_block): +async def validate_block(rpc, block_fixture, at_block): if is_by_hash(at_block): rpc_method = 'eth_getBlockByHash' else: rpc_method = 'eth_getBlockByNumber' # validate without transaction bodies - result, error = call_rpc(rpc, rpc_method, [at_block, False]) + result, error = await call_rpc(rpc, rpc_method, [at_block, False]) assert error is None validate_rpc_block_vs_fixture(result, block_fixture) assert 
len(result['transactions']) == len(block_fixture['transactions']) for index, transaction_fixture in enumerate(block_fixture['transactions']): - validate_transaction_by_index(rpc, transaction_fixture, at_block, index) + await validate_transaction_by_index(rpc, transaction_fixture, at_block, index) - validate_transaction_count(rpc, block_fixture, at_block) + await validate_transaction_count(rpc, block_fixture, at_block) # TODO validate transaction bodies - result, error = call_rpc(rpc, rpc_method, [at_block, True]) + result, error = await call_rpc(rpc, rpc_method, [at_block, True]) # assert error is None # assert result['transactions'] == block_fixture['transactions'] - validate_uncles(rpc, block_fixture, at_block) + await validate_uncles(rpc, block_fixture, at_block) -def validate_last_block(rpc, block_fixture): +async def validate_last_block(rpc, block_fixture): header = block_fixture['blockHeader'] - validate_block(rpc, block_fixture, 'latest') - validate_block(rpc, block_fixture, header['hash']) - validate_block(rpc, block_fixture, int(header['number'], 16)) + await validate_block(rpc, block_fixture, 'latest') + await validate_block(rpc, block_fixture, header['hash']) + await validate_block(rpc, block_fixture, int(header['number'], 16)) -def validate_uncle_count(rpc, block_fixture, at_block): +async def validate_uncle_count(rpc, block_fixture, at_block): if is_by_hash(at_block): rpc_method = 'eth_getUncleCountByBlockHash' else: rpc_method = 'eth_getUncleCountByBlockNumber' num_uncles = len(block_fixture['uncleHeaders']) - assert_rpc_result(rpc, rpc_method, [at_block], hex(num_uncles)) + await assert_rpc_result(rpc, rpc_method, [at_block], hex(num_uncles)) -def validate_uncle_headers(rpc, block_fixture, at_block): +async def validate_uncle_headers(rpc, block_fixture, at_block): if is_by_hash(at_block): rpc_method = 'eth_getUncleByBlockHashAndIndex' else: rpc_method = 'eth_getUncleByBlockNumberAndIndex' for idx, uncle in enumerate(block_fixture['uncleHeaders']): - 
result, error = call_rpc(rpc, rpc_method, [at_block, hex(idx)]) + result, error = await call_rpc(rpc, rpc_method, [at_block, hex(idx)]) assert error is None validate_rpc_block_vs_fixture_header(result, uncle) -def validate_uncles(rpc, block_fixture, at_block): - validate_uncle_count(rpc, block_fixture, at_block) - validate_uncle_headers(rpc, block_fixture, at_block) +async def validate_uncles(rpc, block_fixture, at_block): + await validate_uncle_count(rpc, block_fixture, at_block) + await validate_uncle_headers(rpc, block_fixture, at_block) @pytest.fixture @@ -369,13 +384,14 @@ def chain(chain_without_block_validation): return -def test_rpc_against_fixtures(chain, ipc_server, chain_fixture, fixture_data): - rpc = RPCServer(None) +@pytest.mark.asyncio +async def test_rpc_against_fixtures(chain, ipc_server, chain_fixture, fixture_data): + rpc = RPCServer(MainnetFullChain(None)) - setup_result, setup_error = call_rpc(rpc, 'evm_resetToGenesisFixture', [chain_fixture]) + setup_result, setup_error = await call_rpc(rpc, 'evm_resetToGenesisFixture', [chain_fixture]) assert setup_error is None and setup_result is True, "cannot load chain for %r" % fixture_data - validate_accounts(rpc, chain_fixture['pre']) + await validate_accounts(rpc, chain_fixture['pre']) for block_fixture in chain_fixture['blocks']: should_be_good_block = 'blockHeader' in block_fixture @@ -384,21 +400,21 @@ def test_rpc_against_fixtures(chain, ipc_server, chain_fixture, fixture_data): assert not should_be_good_block continue - block_result, block_error = call_rpc(rpc, 'evm_applyBlockFixture', [block_fixture]) + block_result, block_error = await call_rpc(rpc, 'evm_applyBlockFixture', [block_fixture]) if should_be_good_block: assert block_error is None assert block_result == block_fixture['rlp'] - validate_block(rpc, block_fixture, block_fixture['blockHeader']['hash']) + await validate_block(rpc, block_fixture, block_fixture['blockHeader']['hash']) else: assert block_error is not None if 
chain_fixture.get('lastblockhash', None): for block_fixture in chain_fixture['blocks']: if get_in(['blockHeader', 'hash'], block_fixture) == chain_fixture['lastblockhash']: - validate_last_block(rpc, block_fixture) + await validate_last_block(rpc, block_fixture) - validate_accounts(rpc, chain_fixture['postState']) - validate_accounts(rpc, chain_fixture['pre'], 'earliest') - validate_accounts(rpc, chain_fixture['pre'], 0) + await validate_accounts(rpc, chain_fixture['postState']) + await validate_accounts(rpc, chain_fixture['pre'], 'earliest') + await validate_accounts(rpc, chain_fixture['pre'], 0) diff --git a/tox.ini b/tox.ini index fade700a2a..9f45b9daac 100644 --- a/tox.ini +++ b/tox.ini @@ -28,7 +28,8 @@ commands= rpc-state-homestead: pytest {posargs:tests/trinity/json-fixtures-over-rpc/test_rpc_fixtures.py -k 'GeneralStateTests and not stQuadraticComplexityTest and Homestead'} rpc-state-eip150: pytest {posargs:tests/trinity/json-fixtures-over-rpc/test_rpc_fixtures.py -k 'GeneralStateTests and not stQuadraticComplexityTest and EIP150'} rpc-state-eip158: pytest {posargs:tests/trinity/json-fixtures-over-rpc/test_rpc_fixtures.py -k 'GeneralStateTests and not stQuadraticComplexityTest and EIP158'} - rpc-state-byzantium: pytest {posargs:tests/trinity/json-fixtures-over-rpc/test_rpc_fixtures.py -k 'GeneralStateTests and not stQuadraticComplexityTest and Byzantium'} + # The following test seems to consume a lot of memory. 
Restricting to 3 processes reduces crashes + rpc-state-byzantium: pytest -n3 {posargs:tests/trinity/json-fixtures-over-rpc/test_rpc_fixtures.py -k 'GeneralStateTests and not stQuadraticComplexityTest and Byzantium'} rpc-state-quadratic: pytest {posargs:tests/trinity/json-fixtures-over-rpc/test_rpc_fixtures.py -k 'GeneralStateTests and stQuadraticComplexityTest'} transactions: pytest {posargs:tests/json-fixtures/test_transactions.py} vm: pytest {posargs:tests/json-fixtures/test_virtual_machine.py} diff --git a/trinity/chains/coro.py b/trinity/chains/coro.py new file mode 100644 index 0000000000..f0faa7781f --- /dev/null +++ b/trinity/chains/coro.py @@ -0,0 +1,10 @@ +from trinity.utils.async_dispatch import ( + async_method, +) + + +class AsyncChainMixin: + + coro_get_canonical_block_by_number = async_method('get_canonical_block_by_number') + coro_get_block_by_hash = async_method('get_block_by_hash') + coro_get_block_by_header = async_method('get_block_by_header') diff --git a/trinity/chains/light.py b/trinity/chains/light.py index 5c2740373a..d8d3af0012 100644 --- a/trinity/chains/light.py +++ b/trinity/chains/light.py @@ -54,7 +54,7 @@ ) from trinity.sync.light.service import ( - LightPeerChain, + BaseLightPeerChain, ) if TYPE_CHECKING: @@ -64,14 +64,14 @@ class LightDispatchChain(BaseChain): """ Provide the :class:`BaseChain` API, even though only a - :class:`LightPeerChain` is syncing. Store results locally so that not + :class:`BaseLightPeerChain` is syncing. Store results locally so that not all requests hit the light peer network. 
""" ASYNC_TIMEOUT_SECONDS = 10 _loop = None - def __init__(self, headerdb: BaseHeaderDB, peer_chain: LightPeerChain) -> None: + def __init__(self, headerdb: BaseHeaderDB, peer_chain: BaseLightPeerChain) -> None: self._headerdb = headerdb self._peer_chain = peer_chain self._peer_chain_loop = asyncio.get_event_loop() @@ -135,14 +135,19 @@ def get_block(self) -> BaseBlock: raise NotImplementedError("Chain classes must implement " + inspect.stack()[0][3]) def get_block_by_hash(self, block_hash: Hash32) -> BaseBlock: + raise NotImplementedError("Use coro_get_block_by_hash") + + async def coro_get_block_by_hash(self, block_hash: Hash32) -> BaseBlock: header = self._headerdb.get_block_header_by_hash(block_hash) - return self.get_block_by_header(header) + return await self.get_block_by_header(header) def get_block_by_header(self, header: BlockHeader) -> BaseBlock: + raise NotImplementedError("Use coro_get_block_by_header") + + async def coro_get_block_by_header(self, header: BlockHeader) -> BaseBlock: # TODO check local cache, before hitting peer - block_body = self._run_async( - self._peer_chain.get_block_body_by_hash(header.hash) - ) + + block_body = await self._peer_chain.coro_get_block_body_by_hash(header.hash) block_class = self.get_vm_class_for_block_number(header.block_number).get_block_class() transactions = [ @@ -156,12 +161,15 @@ def get_block_by_header(self, header: BlockHeader) -> BaseBlock: ) def get_canonical_block_by_number(self, block_number: BlockNumber) -> BaseBlock: + raise NotImplementedError("Use coro_get_canonical_block_by_number") + + async def coro_get_canonical_block_by_number(self, block_number: BlockNumber) -> BaseBlock: """ Return the block with the given number from the canonical chain. Raises HeaderNotFound if it is not found. 
""" header = self._headerdb.get_canonical_block_header_by_number(block_number) - return self.get_block_by_header(header) + return await self.get_block_by_header(header) def get_canonical_block_hash(self, block_number: int) -> Hash32: return self._headerdb.get_canonical_block_hash(block_number) @@ -234,12 +242,3 @@ def validate_chain( chain: Tuple[BlockHeader, ...], seal_check_random_sample_rate: int = 1) -> None: raise NotImplementedError("Chain classes must implement " + inspect.stack()[0][3]) - - # - # Async utils - # - T = TypeVar('T') - - def _run_async(self, async_method: Coroutine[T, Any, Any]) -> T: - future = asyncio.run_coroutine_threadsafe(async_method, loop=self._peer_chain_loop) - return future.result(self.ASYNC_TIMEOUT_SECONDS) diff --git a/trinity/chains/mainnet.py b/trinity/chains/mainnet.py index 1e80a27161..feba3131da 100644 --- a/trinity/chains/mainnet.py +++ b/trinity/chains/mainnet.py @@ -1,9 +1,15 @@ from eth.chains.mainnet import ( BaseMainnetChain, + MainnetChain ) +from trinity.chains.coro import AsyncChainMixin from trinity.chains.light import LightDispatchChain +class MainnetFullChain(MainnetChain, AsyncChainMixin): + pass + + class MainnetLightDispatchChain(BaseMainnetChain, LightDispatchChain): pass diff --git a/trinity/chains/ropsten.py b/trinity/chains/ropsten.py index c5edda4651..3f30127065 100644 --- a/trinity/chains/ropsten.py +++ b/trinity/chains/ropsten.py @@ -1,9 +1,15 @@ from eth.chains.ropsten import ( BaseRopstenChain, + RopstenChain ) +from trinity.chains.coro import AsyncChainMixin from trinity.chains.light import LightDispatchChain +class RopstenFullChain(RopstenChain, AsyncChainMixin): + pass + + class RopstenLightDispatchChain(BaseRopstenChain, LightDispatchChain): pass diff --git a/trinity/main.py b/trinity/main.py index d84c37788b..0e11e4185d 100644 --- a/trinity/main.py +++ b/trinity/main.py @@ -76,6 +76,9 @@ from trinity.utils.profiling import ( setup_cprofiler, ) +from trinity.utils.shutdown import ( + exit_on_signal 
+) from trinity.utils.version import ( construct_trinity_client_identifier, ) @@ -260,6 +263,7 @@ def trinity_boot(args: Namespace, database_server_process, networking_process, plugin_manager, + main_endpoint, event_bus ) ) @@ -279,6 +283,7 @@ def trinity_boot(args: Namespace, database_server_process, networking_process, plugin_manager, + main_endpoint, event_bus ) @@ -287,6 +292,7 @@ def kill_trinity_gracefully(logger: logging.Logger, database_server_process: Any, networking_process: Any, plugin_manager: PluginManager, + main_endpoint: Endpoint, event_bus: EventBus, message: str="Trinity shudown complete\n") -> None: # When a user hits Ctrl+C in the terminal, the SIGINT is sent to all processes in the @@ -301,7 +307,8 @@ def kill_trinity_gracefully(logger: logging.Logger, # perform a non-gracefull shutdown if the process takes too long to terminate. logger.info('Keyboard Interrupt: Stopping') plugin_manager.shutdown() - event_bus.shutdown() + main_endpoint.stop() + event_bus.stop() kill_process_gracefully(database_server_process, logger) logger.info('DB server process (pid=%d) terminated', database_server_process.pid) # XXX: This short sleep here seems to avoid us hitting a deadlock when attempting to @@ -336,21 +343,6 @@ def _sigint_handler(*args: Any) -> None: raise -async def exit_on_signal(service_to_exit: BaseService) -> None: - loop = service_to_exit.get_event_loop() - sigint_received = asyncio.Event() - for sig in [signal.SIGINT, signal.SIGTERM]: - # TODO also support Windows - loop.add_signal_handler(sig, sigint_received.set) - - await sigint_received.wait() - try: - await service_to_exit.cancel() - service_to_exit._executor.shutdown(wait=True) - finally: - loop.stop() - - @setup_cprofiler('launch_node') @with_queued_logging def launch_node(args: Namespace, chain_config: ChainConfig, endpoint: Endpoint) -> None: diff --git a/trinity/nodes/base.py b/trinity/nodes/base.py index b6c490c612..47d2278719 100644 --- a/trinity/nodes/base.py +++ 
b/trinity/nodes/base.py @@ -1,10 +1,8 @@ from abc import abstractmethod -import asyncio from pathlib import Path from multiprocessing.managers import ( BaseManager, ) -from threading import Thread from typing import ( Type, ) @@ -17,23 +15,8 @@ from p2p.service import ( BaseService, ) -from trinity.chains import ( - ChainProxy, -) -from trinity.chains.header import ( - AsyncHeaderChainProxy, -) -from trinity.db.chain import ChainDBProxy -from trinity.db.base import DBProxy from trinity.db.header import ( AsyncHeaderDB, - AsyncHeaderDBProxy -) -from trinity.rpc.main import ( - RPCServer, -) -from trinity.rpc.ipc import ( - IPCServer, ) from trinity.config import ( ChainConfig, @@ -44,6 +27,9 @@ from trinity.extensibility.events import ( ResourceAvailableEvent ) +from trinity.utils.db_proxy import ( + create_db_manager +) class Node(BaseService): @@ -106,69 +92,5 @@ def notify_resource_available(self) -> None: resource_type=BaseChain )) - @property - def has_ipc_server(self) -> bool: - return bool(self._jsonrpc_ipc_path) - - def make_ipc_server(self, loop: asyncio.AbstractEventLoop) -> BaseService: - if self.has_ipc_server: - rpc = RPCServer(self.get_chain(), self.get_peer_pool()) - return IPCServer(rpc, self._jsonrpc_ipc_path, loop=loop) - else: - return None - async def _run(self) -> None: - if self.has_ipc_server: - # The RPC server needs its own thread, because it provides a synchcronous - # API which might call into p2p async methods. These sync->async calls - # deadlock if they are run in the same Thread and loop. 
- ipc_loop = self._make_new_loop_thread() - - self._ipc_server = self.make_ipc_server(ipc_loop) - - # keep a copy on self, for later shutdown - self._ipc_loop = ipc_loop - - asyncio.run_coroutine_threadsafe(self._ipc_server.run(), loop=ipc_loop) - await self.get_p2p_server().run() - - async def _cleanup(self) -> None: - # IPC Server requires special handling because it's running in its own loop & thread - if self.has_ipc_server: - await self._ipc_server.threadsafe_cancel() - # Stop the this IPCServer-specific event loop, so that the IPCServer thread will exit - self._ipc_loop.stop() - - def _make_new_loop_thread(self) -> asyncio.AbstractEventLoop: - new_loop = asyncio.new_event_loop() - - def start_loop(loop: asyncio.AbstractEventLoop) -> None: - asyncio.set_event_loop(loop) - loop.run_forever() - loop.close() - - thread = Thread(target=start_loop, args=(new_loop, )) - thread.start() - - return new_loop - - -def create_db_manager(ipc_path: Path) -> BaseManager: - """ - We're still using 'str' here on param ipc_path because an issue with - multi-processing not being able to interpret 'Path' objects correctly - """ - class DBManager(BaseManager): - pass - - # Typeshed definitions for multiprocessing.managers is incomplete, so ignore them for now: - # https://github.com/python/typeshed/blob/85a788dbcaa5e9e9a62e55f15d44530cd28ba830/stdlib/3/multiprocessing/managers.pyi#L3 - DBManager.register('get_db', proxytype=DBProxy) # type: ignore - DBManager.register('get_chaindb', proxytype=ChainDBProxy) # type: ignore - DBManager.register('get_chain', proxytype=ChainProxy) # type: ignore - DBManager.register('get_headerdb', proxytype=AsyncHeaderDBProxy) # type: ignore - DBManager.register('get_header_chain', proxytype=AsyncHeaderChainProxy) # type: ignore - - manager = DBManager(address=str(ipc_path)) # type: ignore - return manager diff --git a/trinity/nodes/full.py b/trinity/nodes/full.py index bba5a8fba9..91afa49b04 100644 --- a/trinity/nodes/full.py +++ 
b/trinity/nodes/full.py @@ -47,6 +47,7 @@ def get_p2p_server(self) -> Server: bootstrap_nodes=self._bootstrap_nodes, preferred_nodes=self._preferred_nodes, token=self.cancel_token, + event_bus=self._plugin_manager.event_bus_endpoint ) return self._p2p_server diff --git a/trinity/nodes/light.py b/trinity/nodes/light.py index b2193391b6..204ce68e50 100644 --- a/trinity/nodes/light.py +++ b/trinity/nodes/light.py @@ -74,6 +74,7 @@ def get_p2p_server(self) -> LightServer: preferred_nodes=self._preferred_nodes, use_discv5=self._use_discv5, token=self.cancel_token, + event_bus=self._plugin_manager.event_bus_endpoint, ) return self._p2p_server diff --git a/trinity/nodes/mainnet.py b/trinity/nodes/mainnet.py index d18a330a82..fa5d49df37 100644 --- a/trinity/nodes/mainnet.py +++ b/trinity/nodes/mainnet.py @@ -1,8 +1,5 @@ -from eth.chains.mainnet import ( - MainnetChain, -) - from trinity.chains.mainnet import ( + MainnetFullChain, MainnetLightDispatchChain, ) from trinity.nodes.light import LightNode @@ -10,7 +7,7 @@ class MainnetFullNode(FullNode): - chain_class = MainnetChain + chain_class = MainnetFullChain class MainnetLightNode(LightNode): diff --git a/trinity/nodes/ropsten.py b/trinity/nodes/ropsten.py index 0228169455..b2c40aca92 100644 --- a/trinity/nodes/ropsten.py +++ b/trinity/nodes/ropsten.py @@ -1,8 +1,5 @@ -from eth.chains.ropsten import ( - RopstenChain, -) - from trinity.chains.ropsten import ( + RopstenFullChain, RopstenLightDispatchChain, ) from trinity.nodes.light import LightNode @@ -10,7 +7,7 @@ class RopstenFullNode(FullNode): - chain_class = RopstenChain + chain_class = RopstenFullChain class RopstenLightNode(LightNode): diff --git a/trinity/plugins/builtin/json_rpc/__init__.py b/trinity/plugins/builtin/json_rpc/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/trinity/plugins/builtin/json_rpc/plugin.py b/trinity/plugins/builtin/json_rpc/plugin.py new file mode 100644 index 0000000000..722dff7f83 --- /dev/null +++ 
b/trinity/plugins/builtin/json_rpc/plugin.py @@ -0,0 +1,70 @@ +from argparse import ( + ArgumentParser, + _SubParsersAction, +) +import asyncio + +from trinity.constants import ( + SYNC_LIGHT +) +from trinity.extensibility import ( + BaseIsolatedPlugin, +) +from trinity.plugins.builtin.light_peer_chain_bridge.light_peer_chain_bridge import ( + EventBusLightPeerChain, +) +from trinity.rpc.main import ( + RPCServer, +) +from trinity.rpc.ipc import ( + IPCServer, +) +from trinity.utils.db_proxy import ( + create_db_manager +) +from trinity.utils.shutdown import ( + exit_on_signal +) + + +class JsonRpcServerPlugin(BaseIsolatedPlugin): + + @property + def name(self) -> str: + return "JSON-RPC Server" + + def should_start(self) -> bool: + return not self.context.args.disable_rpc + + def configure_parser(self, arg_parser: ArgumentParser, subparser: _SubParsersAction) -> None: + arg_parser.add_argument( + "--disable-rpc", + action="store_true", + help="Disables the JSON-RPC Server", + ) + + def start(self) -> None: + self.logger.info('JSON-RPC Server started') + self.context.event_bus.connect() + + db_manager = create_db_manager(self.context.chain_config.database_ipc_path) + db_manager.connect() + + chain_class = self.context.chain_config.node_class.chain_class + + if self.context.chain_config.sync_mode == SYNC_LIGHT: + header_db = db_manager.get_headerdb() # type: ignore + event_bus_light_peer_chain = EventBusLightPeerChain(self.context.event_bus) + chain = chain_class(header_db, peer_chain=event_bus_light_peer_chain) + else: + db = db_manager.get_db() # type: ignore + chain = chain_class(db) + + rpc = RPCServer(chain, self.context.event_bus) + ipc_server = IPCServer(rpc, self.context.chain_config.jsonrpc_ipc_path) + + loop = asyncio.get_event_loop() + asyncio.ensure_future(exit_on_signal(ipc_server, self.context.event_bus)) + asyncio.ensure_future(ipc_server.run()) + loop.run_forever() + loop.close() diff --git 
a/trinity/plugins/builtin/light_peer_chain_bridge/__init__.py b/trinity/plugins/builtin/light_peer_chain_bridge/__init__.py new file mode 100644 index 0000000000..de85e21e92 --- /dev/null +++ b/trinity/plugins/builtin/light_peer_chain_bridge/__init__.py @@ -0,0 +1,4 @@ +from .light_peer_chain_bridge import ( # noqa: F401 + EventBusLightPeerChain, + LightPeerChainEventBusHandler, +) diff --git a/trinity/plugins/builtin/light_peer_chain_bridge/light_peer_chain_bridge.py b/trinity/plugins/builtin/light_peer_chain_bridge/light_peer_chain_bridge.py new file mode 100644 index 0000000000..0475aa1dc2 --- /dev/null +++ b/trinity/plugins/builtin/light_peer_chain_bridge/light_peer_chain_bridge.py @@ -0,0 +1,257 @@ +from typing import ( + List, + Type, + TypeVar, +) + +from cancel_token import ( + CancelToken, +) + +from eth_typing import ( + Address, + Hash32, +) + +from eth.rlp.accounts import ( + Account, +) +from eth.rlp.headers import ( + BlockHeader, +) +from eth.rlp.receipts import ( + Receipt, +) + +from lahja import ( + BaseEvent, + BaseRequestResponseEvent, + Endpoint, +) + +from p2p.service import ( + BaseService, +) + +from trinity.utils.async_errors import ( + await_and_wrap_errors, +) +from trinity.rlp.block_body import BlockBody +from trinity.sync.light.service import ( + BaseLightPeerChain, +) + + +class BaseLightPeerChainResponse(BaseEvent): + + def __init__(self, error: Exception) -> None: + self.error = error + + +class BlockHeaderResponse(BaseLightPeerChainResponse): + + def __init__(self, block_header: BlockHeader, error: Exception=None) -> None: + super().__init__(error) + self.block_header = block_header + + +class BlockBodyResponse(BaseLightPeerChainResponse): + + def __init__(self, block_body: BlockBody, error: Exception=None) -> None: + super().__init__(error) + self.block_body = block_body + + +class ReceiptsResponse(BaseLightPeerChainResponse): + + def __init__(self, receipts: List[Receipt], error: Exception=None) -> None: + super().__init__(error) 
+ self.receipts = receipts + + +class AccountResponse(BaseLightPeerChainResponse): + + def __init__(self, account: Account, error: Exception=None) -> None: + super().__init__(error) + self.account = account + + +class BytesResponse(BaseLightPeerChainResponse): + + def __init__(self, bytez: bytes, error: Exception=None) -> None: + super().__init__(error) + self.bytez = bytez + + +class GetBlockHeaderByHashRequest(BaseRequestResponseEvent[BlockHeaderResponse]): + + def __init__(self, block_hash: Hash32) -> None: + self.block_hash = block_hash + + @staticmethod + def expected_response_type() -> Type[BlockHeaderResponse]: + return BlockHeaderResponse + + +class GetBlockBodyByHashRequest(BaseRequestResponseEvent[BlockBodyResponse]): + + def __init__(self, block_hash: Hash32) -> None: + self.block_hash = block_hash + + @staticmethod + def expected_response_type() -> Type[BlockBodyResponse]: + return BlockBodyResponse + + +class GetReceiptsRequest(BaseRequestResponseEvent[ReceiptsResponse]): + + def __init__(self, block_hash: Hash32) -> None: + self.block_hash = block_hash + + @staticmethod + def expected_response_type() -> Type[ReceiptsResponse]: + return ReceiptsResponse + + +class GetAccountRequest(BaseRequestResponseEvent[AccountResponse]): + + def __init__(self, block_hash: Hash32, address: Address) -> None: + self.block_hash = block_hash + self.address = address + + @staticmethod + def expected_response_type() -> Type[AccountResponse]: + return AccountResponse + + +class GetContractCodeRequest(BaseRequestResponseEvent[BytesResponse]): + + def __init__(self, block_hash: Hash32, address: Address) -> None: + self.block_hash = block_hash + self.address = address + + @staticmethod + def expected_response_type() -> Type[BytesResponse]: + return BytesResponse + + +class LightPeerChainEventBusHandler(BaseService): + """ + The ``LightPeerChainEventBusHandler`` listens for certain events on the eventbus and + delegates them to the ``LightPeerChain`` to get answers. 
It then propagates responses + back to the caller. + """ + + def __init__(self, + chain: BaseLightPeerChain, + event_bus: Endpoint, + token: CancelToken = None) -> None: + super().__init__(token) + self.chain = chain + self.event_bus = event_bus + + async def _run(self) -> None: + self.logger.info("Running LightPeerChainEventBusHandler") + + self.run_daemon_task(self.handle_get_blockheader_by_hash_requests()) + self.run_daemon_task(self.handle_get_blockbody_by_hash_requests()) + self.run_daemon_task(self.handle_get_receipts_by_hash_requests()) + self.run_daemon_task(self.handle_get_account_requests()) + self.run_daemon_task(self.handle_get_contract_code_requests()) + + async def handle_get_blockheader_by_hash_requests(self) -> None: + async for event in self.event_bus.stream(GetBlockHeaderByHashRequest): + + val, error = await await_and_wrap_errors( + self.chain.coro_get_block_header_by_hash(event.block_hash) + ) + + self.event_bus.broadcast( + event.expected_response_type()(val, error), + event.broadcast_config() + ) + + async def handle_get_blockbody_by_hash_requests(self) -> None: + async for event in self.event_bus.stream(GetBlockBodyByHashRequest): + + val, error = await await_and_wrap_errors( + self.chain.coro_get_block_body_by_hash(event.block_hash) + ) + + self.event_bus.broadcast( + event.expected_response_type()(val, error), + event.broadcast_config() + ) + + async def handle_get_receipts_by_hash_requests(self) -> None: + async for event in self.event_bus.stream(GetReceiptsRequest): + + val, error = await await_and_wrap_errors(self.chain.coro_get_receipts(event.block_hash)) + + self.event_bus.broadcast( + event.expected_response_type()(val, error), + event.broadcast_config() + ) + + async def handle_get_account_requests(self) -> None: + async for event in self.event_bus.stream(GetAccountRequest): + + val, error = await await_and_wrap_errors( + self.chain.coro_get_account(event.block_hash, event.address) + ) + + self.event_bus.broadcast( + 
event.expected_response_type()(val, error), + event.broadcast_config() + ) + + async def handle_get_contract_code_requests(self) -> None: + + async for event in self.event_bus.stream(GetContractCodeRequest): + + val, error = await await_and_wrap_errors( + self.chain.coro_get_contract_code(event.block_hash, event.address) + ) + + self.event_bus.broadcast( + event.expected_response_type()(val, error), + event.broadcast_config() + ) + + +class EventBusLightPeerChain(BaseLightPeerChain): + """ + The ``EventBusLightPeerChain`` is an implementation of the ``BaseLightPeerChain`` that can + be used from within any process. + """ + + def __init__(self, event_bus: Endpoint) -> None: + self.event_bus = event_bus + + async def coro_get_block_header_by_hash(self, block_hash: Hash32) -> BlockHeader: + event = GetBlockHeaderByHashRequest(block_hash) + return self._pass_or_raise(await self.event_bus.request(event)).block_header + + async def coro_get_block_body_by_hash(self, block_hash: Hash32) -> BlockBody: + event = GetBlockBodyByHashRequest(block_hash) + return self._pass_or_raise(await self.event_bus.request(event)).block_body + + async def coro_get_receipts(self, block_hash: Hash32) -> List[Receipt]: + event = GetReceiptsRequest(block_hash) + return self._pass_or_raise(await self.event_bus.request(event)).receipts + + async def coro_get_account(self, block_hash: Hash32, address: Address) -> Account: + event = GetAccountRequest(block_hash, address) + return self._pass_or_raise(await self.event_bus.request(event)).account + + async def coro_get_contract_code(self, block_hash: Hash32, address: Address) -> bytes: + event = GetContractCodeRequest(block_hash, address) + return self._pass_or_raise(await self.event_bus.request(event)).bytez + + TResponse = TypeVar("TResponse", bound=BaseLightPeerChainResponse) + + def _pass_or_raise(self, response: TResponse) -> TResponse: + if response.error is not None: + raise response.error + + return response diff --git 
a/trinity/plugins/builtin/light_peer_chain_bridge/plugin.py b/trinity/plugins/builtin/light_peer_chain_bridge/plugin.py new file mode 100644 index 0000000000..f85ebea818 --- /dev/null +++ b/trinity/plugins/builtin/light_peer_chain_bridge/plugin.py @@ -0,0 +1,55 @@ +import asyncio +from typing import ( + cast +) + +from eth.chains.base import ( + BaseChain +) + +from trinity.constants import ( + SYNC_LIGHT +) +from trinity.extensibility import ( + BaseEvent, + BasePlugin, +) +from trinity.chains.light import ( + LightDispatchChain, +) +from trinity.extensibility.events import ( + ResourceAvailableEvent +) +from trinity.plugins.builtin.light_peer_chain_bridge import ( + LightPeerChainEventBusHandler +) + + +class LightPeerChainBridgePlugin(BasePlugin): + """ + The ``LightPeerChainBridgePlugin`` runs in the ``networking`` process and acts as a bridge + between other processes and the ``LightPeerChain``. + It runs only in ``light`` mode. + Other plugins can instantiate the ``EventBusLightPeerChain`` from separate processes to + interact with the ``LightPeerChain`` indirectly. 
+ """ + + chain: BaseChain = None + + @property + def name(self) -> str: + return "LightPeerChain Bridge" + + def should_start(self) -> bool: + return self.chain is not None and self.context.chain_config.sync_mode == SYNC_LIGHT + + def handle_event(self, activation_event: BaseEvent) -> None: + if isinstance(activation_event, ResourceAvailableEvent): + if activation_event.resource_type is BaseChain: + self.chain = activation_event.resource + + def start(self) -> None: + self.logger.info('LightPeerChain Bridge started') + chain = cast(LightDispatchChain, self.chain) + handler = LightPeerChainEventBusHandler(chain._peer_chain, self.context.event_bus) + asyncio.ensure_future(handler.run()) diff --git a/trinity/plugins/registry.py b/trinity/plugins/registry.py index a86db7f041..8281e1ff5d 100644 --- a/trinity/plugins/registry.py +++ b/trinity/plugins/registry.py @@ -6,9 +6,15 @@ from trinity.plugins.builtin.fix_unclean_shutdown.plugin import ( FixUncleanShutdownPlugin ) +from trinity.plugins.builtin.json_rpc.plugin import ( + JsonRpcServerPlugin, +) from trinity.plugins.builtin.tx_pool.plugin import ( TxPlugin, ) +from trinity.plugins.builtin.light_peer_chain_bridge.plugin import ( + LightPeerChainBridgePlugin +) def is_ipython_available() -> bool: @@ -27,5 +33,7 @@ def is_ipython_available() -> bool: ENABLED_PLUGINS = [ AttachPlugin() if is_ipython_available() else AttachPlugin(use_ipython=False), FixUncleanShutdownPlugin(), + JsonRpcServerPlugin(), + LightPeerChainBridgePlugin(), TxPlugin(), ] diff --git a/trinity/rpc/format.py b/trinity/rpc/format.py index 067e70ea13..7b5f4eab36 100644 --- a/trinity/rpc/format.py +++ b/trinity/rpc/format.py @@ -1,3 +1,4 @@ +import asyncio import functools from typing import ( Any, @@ -21,7 +22,7 @@ import rlp from eth.chains.base import ( - BaseChain + AsyncChain ) from eth.constants import ( CREATE_CONTRACT_ADDRESS, @@ -102,7 +103,7 @@ def header_to_dict(header: BlockHeader) -> Dict[str, str]: def block_to_dict(block: BaseBlock, - 
chain: BaseChain, + chain: AsyncChain, include_transactions: bool) -> Dict[str, Union[str, List[str]]]: header_dict = header_to_dict(block.header) @@ -125,13 +126,22 @@ def block_to_dict(block: BaseBlock, def format_params(*formatters: Any) -> Callable[..., Any]: def decorator(func: Callable[..., Any]) -> Callable[..., Any]: - @functools.wraps(func) - def formatted_func(self: Any, *args: Any) -> Callable[..., Any]: - if len(formatters) != len(args): - raise TypeError("could not apply %d formatters to %r" % (len(formatters), args)) - formatted = (formatter(arg) for formatter, arg in zip(formatters, args)) - return func(self, *formatted) - return formatted_func + if asyncio.iscoroutinefunction(func): + @functools.wraps(func) + async def async_formatted_func(self: Any, *args: Any) -> Callable[..., Any]: + if len(formatters) != len(args): + raise TypeError("could not apply %d formatters to %r" % (len(formatters), args)) + formatted = (formatter(arg) for formatter, arg in zip(formatters, args)) + return await func(self, *formatted) + return async_formatted_func + else: + @functools.wraps(func) + def formatted_func(self: Any, *args: Any) -> Callable[..., Any]: + if len(formatters) != len(args): + raise TypeError("could not apply %d formatters to %r" % (len(formatters), args)) + formatted = (formatter(arg) for formatter, arg in zip(formatters, args)) + return func(self, *formatted) + return formatted_func return decorator diff --git a/trinity/rpc/ipc.py b/trinity/rpc/ipc.py index 01012b703b..576d2bd475 100644 --- a/trinity/rpc/ipc.py +++ b/trinity/rpc/ipc.py @@ -99,7 +99,7 @@ async def connection_loop(execute_rpc: Callable[[Any], Any], continue try: - result = execute_rpc(request) + result = await execute_rpc(request) except Exception as e: logger.exception("Unrecognized exception while executing RPC") await cancel_token.cancellable_wait( diff --git a/trinity/rpc/main.py b/trinity/rpc/main.py index c7eaa5e705..97e3b32773 100644 --- a/trinity/rpc/main.py +++ 
b/trinity/rpc/main.py @@ -12,11 +12,11 @@ ) from eth.chains.base import ( - BaseChain + AsyncChain, ) -from p2p.peer import ( - PeerPool +from lahja import ( + Endpoint ) from trinity.rpc.modules import ( @@ -74,11 +74,11 @@ class RPCServer: Web3, ) - def __init__(self, chain: BaseChain=None, peer_pool: PeerPool=None) -> None: + def __init__(self, chain: AsyncChain=None, event_bus: Endpoint=None) -> None: self.modules: Dict[str, RPCModule] = {} self.chain = chain for M in self.module_classes: - self.modules[M.__name__.lower()] = M(chain, peer_pool) + self.modules[M.__name__.lower()] = M(chain, event_bus) if len(self.modules) != len(self.module_classes): raise ValueError("apparent name conflict in RPC module_classes", self.module_classes) @@ -101,9 +101,9 @@ def _lookup_method(self, rpc_method: str) -> Any: except AttributeError: raise ValueError("Method not implemented: %r" % rpc_method) - def _get_result(self, - request: Dict[str, Any], - debug: bool=False) -> Tuple[Any, Union[Exception, str]]: + async def _get_result(self, + request: Dict[str, Any], + debug: bool=False) -> Tuple[Any, Union[Exception, str]]: ''' :returns: (result, error) - result is None if error is provided. Error must be convertable to string with ``str(error)``. 
@@ -116,7 +116,7 @@ def _get_result(self, method = self._lookup_method(request['method']) params = request.get('params', []) - result = method(*params) + result = await method(*params) if request['method'] == 'evm_resetToGenesisFixture': self.chain, result = result, True @@ -135,19 +135,19 @@ def _get_result(self, else: return result, None - def execute(self, request: Dict[str, Any]) -> str: + async def execute(self, request: Dict[str, Any]) -> str: ''' The key entry point for all incoming requests ''' - result, error = self._get_result(request) + result, error = await self._get_result(request) return generate_response(request, result, error) @property - def chain(self) -> BaseChain: + def chain(self) -> AsyncChain: return self.__chain @chain.setter - def chain(self, new_chain: BaseChain) -> None: + def chain(self, new_chain: AsyncChain) -> None: self.__chain = new_chain for module in self.modules.values(): module.set_chain(new_chain) diff --git a/trinity/rpc/modules/eth.py b/trinity/rpc/modules/eth.py index a1549ba3ad..c4dfaf03dd 100644 --- a/trinity/rpc/modules/eth.py +++ b/trinity/rpc/modules/eth.py @@ -23,7 +23,7 @@ ZERO_ADDRESS, ) from eth.chains.base import ( - BaseChain, + AsyncChain, ) from eth.rlp.blocks import ( BaseBlock @@ -55,45 +55,47 @@ ) -def get_header(chain: BaseChain, at_block: Union[str, int]) -> BlockHeader: +async def get_header(chain: AsyncChain, at_block: Union[str, int]) -> BlockHeader: if at_block == 'pending': raise NotImplementedError("RPC interface does not support the 'pending' block at this time") elif at_block == 'latest': at_header = chain.get_canonical_head() elif at_block == 'earliest': # TODO find if genesis block can be non-zero. Why does 'earliest' option even exist? 
- at_header = chain.get_canonical_block_by_number(0).header + block = await chain.coro_get_canonical_block_by_number(0) + at_header = block.header # mypy doesn't have user defined type guards yet # https://github.com/python/mypy/issues/5206 elif is_integer(at_block) and at_block >= 0: # type: ignore - at_header = chain.get_canonical_block_by_number(at_block).header + block = await chain.coro_get_canonical_block_by_number(at_block) + at_header = block.header else: raise TypeError("Unrecognized block reference: %r" % at_block) return at_header -def account_db_at_block(chain: BaseChain, - at_block: Union[str, int], - read_only: bool=True) ->BaseAccountDB: - at_header = get_header(chain, at_block) +async def account_db_at_block(chain: AsyncChain, + at_block: Union[str, int], + read_only: bool=True) ->BaseAccountDB: + at_header = await get_header(chain, at_block) vm = chain.get_vm(at_header) return vm.state.account_db -def get_block_at_number(chain: BaseChain, at_block: Union[str, int]) -> BaseBlock: +async def get_block_at_number(chain: AsyncChain, at_block: Union[str, int]) -> BaseBlock: # mypy doesn't have user defined type guards yet # https://github.com/python/mypy/issues/5206 if is_integer(at_block) and at_block >= 0: # type: ignore # optimization to avoid requesting block, then header, then block again - return chain.get_canonical_block_by_number(at_block) + return await chain.coro_get_canonical_block_by_number(at_block) else: - at_header = get_header(chain, at_block) - return chain.get_block_by_header(at_header) + at_header = await get_header(chain, at_block) + return await chain.coro_get_block_by_header(at_header) def dict_to_spoof_transaction( - chain: BaseChain, + chain: AsyncChain, header: BlockHeader, transaction_dict: Dict[str, Any]) -> SpoofTransaction: """ @@ -129,134 +131,136 @@ class Eth(RPCModule): Any attribute without an underscore is publicly accessible.
''' - def accounts(self) -> List[str]: + async def accounts(self) -> List[str]: # trinity does not manage accounts for the user return [] - def blockNumber(self) -> str: + async def blockNumber(self) -> str: num = self._chain.get_canonical_head().block_number return hex(num) @format_params(identity, to_int_if_hex) - def call(self, txn_dict: Dict[str, Any], at_block: Union[str, int]) -> str: - header = get_header(self._chain, at_block) + async def call(self, txn_dict: Dict[str, Any], at_block: Union[str, int]) -> str: + header = await get_header(self._chain, at_block) validate_transaction_call_dict(txn_dict, self._chain.get_vm(header)) transaction = dict_to_spoof_transaction(self._chain, header, txn_dict) result = self._chain.get_transaction_result(transaction, header) return encode_hex(result) - def coinbase(self) -> Hash32: + async def coinbase(self) -> Hash32: raise NotImplementedError() @format_params(identity, to_int_if_hex) - def estimateGas(self, txn_dict: Dict[str, Any], at_block: Union[str, int]) -> str: - header = get_header(self._chain, at_block) + async def estimateGas(self, txn_dict: Dict[str, Any], at_block: Union[str, int]) -> str: + header = await get_header(self._chain, at_block) validate_transaction_gas_estimation_dict(txn_dict, self._chain.get_vm(header)) transaction = dict_to_spoof_transaction(self._chain, header, txn_dict) gas = self._chain.estimate_gas(transaction, header) return hex(gas) - def gasPrice(self) -> int: + async def gasPrice(self) -> int: raise NotImplementedError() @format_params(decode_hex, to_int_if_hex) - def getBalance(self, address: Address, at_block: Union[str, int]) -> str: - account_db = account_db_at_block(self._chain, at_block) + async def getBalance(self, address: Address, at_block: Union[str, int]) -> str: + account_db = await account_db_at_block(self._chain, at_block) balance = account_db.get_balance(address) return hex(balance) @format_params(decode_hex, identity) - def getBlockByHash(self, - block_hash: Hash32, - 
include_transactions: bool) -> Dict[str, Union[str, List[str]]]: - block = self._chain.get_block_by_hash(block_hash) + async def getBlockByHash(self, + block_hash: Hash32, + include_transactions: bool) -> Dict[str, Union[str, List[str]]]: + block = await self._chain.coro_get_block_by_hash(block_hash) return block_to_dict(block, self._chain, include_transactions) @format_params(to_int_if_hex, identity) - def getBlockByNumber(self, - at_block: Union[str, int], - include_transactions: bool) -> Dict[str, Union[str, List[str]]]: - block = get_block_at_number(self._chain, at_block) + async def getBlockByNumber(self, + at_block: Union[str, int], + include_transactions: bool) -> Dict[str, Union[str, List[str]]]: + block = await get_block_at_number(self._chain, at_block) return block_to_dict(block, self._chain, include_transactions) @format_params(decode_hex) - def getBlockTransactionCountByHash(self, block_hash: Hash32) -> str: - block = self._chain.get_block_by_hash(block_hash) + async def getBlockTransactionCountByHash(self, block_hash: Hash32) -> str: + block = await self._chain.coro_get_block_by_hash(block_hash) return hex(len(block.transactions)) @format_params(to_int_if_hex) - def getBlockTransactionCountByNumber(self, at_block: Union[str, int]) -> str: - block = get_block_at_number(self._chain, at_block) + async def getBlockTransactionCountByNumber(self, at_block: Union[str, int]) -> str: + block = await get_block_at_number(self._chain, at_block) return hex(len(block.transactions)) @format_params(decode_hex, to_int_if_hex) - def getCode(self, address: Address, at_block: Union[str, int]) -> str: - account_db = account_db_at_block(self._chain, at_block) + async def getCode(self, address: Address, at_block: Union[str, int]) -> str: + account_db = await account_db_at_block(self._chain, at_block) code = account_db.get_code(address) return encode_hex(code) @format_params(decode_hex, to_int_if_hex, to_int_if_hex) - def getStorageAt(self, address: Address, position: int, 
at_block: Union[str, int]) -> str: + async def getStorageAt(self, address: Address, position: int, at_block: Union[str, int]) -> str: if not is_integer(position) or position < 0: raise TypeError("Position of storage must be a whole number, but was: %r" % position) - account_db = account_db_at_block(self._chain, at_block) + account_db = await account_db_at_block(self._chain, at_block) stored_val = account_db.get_storage(address, position) return encode_hex(int_to_big_endian(stored_val)) @format_params(decode_hex, to_int_if_hex) - def getTransactionByBlockHashAndIndex(self, block_hash: Hash32, index: int) -> Dict[str, str]: - block = self._chain.get_block_by_hash(block_hash) + async def getTransactionByBlockHashAndIndex(self, + block_hash: Hash32, + index: int) -> Dict[str, str]: + block = await self._chain.coro_get_block_by_hash(block_hash) transaction = block.transactions[index] return transaction_to_dict(transaction) @format_params(to_int_if_hex, to_int_if_hex) - def getTransactionByBlockNumberAndIndex(self, - at_block: Union[str, int], - index: int) -> Dict[str, str]: - block = get_block_at_number(self._chain, at_block) + async def getTransactionByBlockNumberAndIndex(self, + at_block: Union[str, int], + index: int) -> Dict[str, str]: + block = await get_block_at_number(self._chain, at_block) transaction = block.transactions[index] return transaction_to_dict(transaction) @format_params(decode_hex, to_int_if_hex) - def getTransactionCount(self, address: Address, at_block: Union[str, int]) -> str: - account_db = account_db_at_block(self._chain, at_block) + async def getTransactionCount(self, address: Address, at_block: Union[str, int]) -> str: + account_db = await account_db_at_block(self._chain, at_block) nonce = account_db.get_nonce(address) return hex(nonce) @format_params(decode_hex) - def getUncleCountByBlockHash(self, block_hash: Hash32) -> str: - block = self._chain.get_block_by_hash(block_hash) + async def getUncleCountByBlockHash(self, block_hash: Hash32) 
-> str: + block = await self._chain.coro_get_block_by_hash(block_hash) return hex(len(block.uncles)) @format_params(to_int_if_hex) - def getUncleCountByBlockNumber(self, at_block: Union[str, int]) -> str: - block = get_block_at_number(self._chain, at_block) + async def getUncleCountByBlockNumber(self, at_block: Union[str, int]) -> str: + block = await get_block_at_number(self._chain, at_block) return hex(len(block.uncles)) @format_params(decode_hex, to_int_if_hex) - def getUncleByBlockHashAndIndex(self, block_hash: Hash32, index: int) -> Dict[str, str]: - block = self._chain.get_block_by_hash(block_hash) + async def getUncleByBlockHashAndIndex(self, block_hash: Hash32, index: int) -> Dict[str, str]: + block = await self._chain.coro_get_block_by_hash(block_hash) uncle = block.uncles[index] return header_to_dict(uncle) @format_params(to_int_if_hex, to_int_if_hex) - def getUncleByBlockNumberAndIndex(self, - at_block: Union[str, int], - index: int) -> Dict[str, str]: - block = get_block_at_number(self._chain, at_block) + async def getUncleByBlockNumberAndIndex(self, + at_block: Union[str, int], + index: int) -> Dict[str, str]: + block = await get_block_at_number(self._chain, at_block) uncle = block.uncles[index] return header_to_dict(uncle) - def hashrate(self) -> str: + async def hashrate(self) -> str: raise NotImplementedError() - def mining(self) -> bool: + async def mining(self) -> bool: return False - def protocolVersion(self) -> str: + async def protocolVersion(self) -> str: return "63" - def syncing(self) -> bool: + async def syncing(self) -> bool: raise NotImplementedError() diff --git a/trinity/rpc/modules/evm.py b/trinity/rpc/modules/evm.py index 77bd139a89..5a08160923 100644 --- a/trinity/rpc/modules/evm.py +++ b/trinity/rpc/modules/evm.py @@ -25,16 +25,16 @@ class EVM(RPCModule): @format_params(normalize_blockchain_fixtures) - def resetToGenesisFixture(self, chain_info: Any) -> Chain: + async def resetToGenesisFixture(self, chain_info: Any) -> Chain: ''' 
This method is a special case. It returns a new chain object which is then replaced inside :class:`~trinity.rpc.main.RPCServer` for all future calls. ''' - return new_chain_from_fixture(chain_info) + return new_chain_from_fixture(chain_info, type(self._chain)) @format_params(normalize_block) - def applyBlockFixture(self, block_info: Any) -> str: + async def applyBlockFixture(self, block_info: Any) -> str: ''' This method is a special case. It returns a new chain object which is then replaced inside :class:`~trinity.rpc.main.RPCServer` diff --git a/trinity/rpc/modules/main.py b/trinity/rpc/modules/main.py index db51f6444a..b54069f253 100644 --- a/trinity/rpc/modules/main.py +++ b/trinity/rpc/modules/main.py @@ -1,18 +1,18 @@ from eth.chains.base import ( - BaseChain + AsyncChain ) -from p2p.peer import ( - PeerPool, +from lahja import ( + Endpoint ) class RPCModule: _chain = None - def __init__(self, chain: BaseChain, peer_pool: PeerPool) -> None: + def __init__(self, chain: AsyncChain, event_bus: Endpoint) -> None: self._chain = chain - self._peer_pool = peer_pool + self._event_bus = event_bus - def set_chain(self, chain: BaseChain) -> None: + def set_chain(self, chain: AsyncChain) -> None: self._chain = chain diff --git a/trinity/rpc/modules/net.py b/trinity/rpc/modules/net.py index 1b24219cfb..a160d7caa8 100644 --- a/trinity/rpc/modules/net.py +++ b/trinity/rpc/modules/net.py @@ -1,22 +1,26 @@ +from p2p.events import ( + PeerCountRequest +) from trinity.rpc.modules import ( RPCModule, ) class Net(RPCModule): - def version(self) -> str: + async def version(self) -> str: """ Returns the current network ID. 
""" return str(self._chain.network_id) - def peerCount(self) -> str: + async def peerCount(self) -> str: """ Return the number of peers that are currently connected to the node """ - return hex(len(self._peer_pool)) # type: ignore + response = await self._event_bus.request(PeerCountRequest()) + return hex(response.peer_count) - def listening(self) -> bool: + async def listening(self) -> bool: """ Return `True` if the client is actively listening for network connections """ diff --git a/trinity/rpc/modules/web3.py b/trinity/rpc/modules/web3.py index 724dcd846e..f3deae1252 100644 --- a/trinity/rpc/modules/web3.py +++ b/trinity/rpc/modules/web3.py @@ -9,13 +9,13 @@ class Web3(RPCModule): - def clientVersion(self) -> str: + async def clientVersion(self) -> str: """ Returns the current client version. """ return construct_trinity_client_identifier() - def sha3(self, data: str) -> str: + async def sha3(self, data: str) -> str: """ Returns Keccak-256 of the given data. """ diff --git a/trinity/server.py b/trinity/server.py index 5781cba59b..92d49d7525 100644 --- a/trinity/server.py +++ b/trinity/server.py @@ -13,6 +13,10 @@ from cancel_token import CancelToken, OperationCancelled +from lahja import ( + Endpoint +) + from eth.chains import AsyncChain from eth_typing import BlockNumber @@ -86,9 +90,11 @@ def __init__(self, bootstrap_nodes: Tuple[Node, ...] 
= None, preferred_nodes: Sequence[Node] = None, use_discv5: bool = False, + event_bus: Endpoint = None, token: CancelToken = None, ) -> None: super().__init__(token) + self.event_bus = event_bus self.headerdb = headerdb self.chaindb = chaindb self.chain = chain @@ -137,6 +143,7 @@ def _make_peer_pool(self) -> PeerPool: self.chain.get_vm_configuration(), max_peers=self.max_peers, token=self.cancel_token, + event_bus=self.event_bus, ) async def _run(self) -> None: diff --git a/trinity/sync/light/service.py b/trinity/sync/light/service.py index c49e8c7412..985c8586bb 100644 --- a/trinity/sync/light/service.py +++ b/trinity/sync/light/service.py @@ -1,3 +1,7 @@ +from abc import ( + ABC, + abstractmethod, +) import asyncio from functools import ( partial, @@ -71,7 +75,30 @@ from trinity.utils.les import gen_request_id -class LightPeerChain(PeerSubscriber, BaseService): +class BaseLightPeerChain(ABC): + + @abstractmethod + async def coro_get_block_header_by_hash(self, block_hash: Hash32) -> BlockHeader: + pass + + @abstractmethod + async def coro_get_block_body_by_hash(self, block_hash: Hash32) -> BlockBody: + pass + + @abstractmethod + async def coro_get_receipts(self, block_hash: Hash32) -> List[Receipt]: + pass + + @abstractmethod + async def coro_get_account(self, block_hash: Hash32, address: Address) -> Account: + pass + + @abstractmethod + async def coro_get_contract_code(self, block_hash: Hash32, address: Address) -> bytes: + pass + + +class LightPeerChain(PeerSubscriber, BaseService, BaseLightPeerChain): reply_timeout = REPLY_TIMEOUT headerdb: BaseAsyncHeaderDB = None @@ -125,7 +152,7 @@ def callback(r: protocol._DecodedMsgType) -> None: @alru_cache(maxsize=1024, cache_exceptions=False) @service_timeout(COMPLETION_TIMEOUT) - async def get_block_header_by_hash(self, block_hash: Hash32) -> BlockHeader: + async def coro_get_block_header_by_hash(self, block_hash: Hash32) -> BlockHeader: """ :param block_hash: hash of the header to retrieve @@ -140,7 +167,7 @@ async 
def get_block_header_by_hash(self, block_hash: Hash32) -> BlockHeader: @alru_cache(maxsize=1024, cache_exceptions=False) @service_timeout(COMPLETION_TIMEOUT) - async def get_block_body_by_hash(self, block_hash: Hash32) -> BlockBody: + async def coro_get_block_body_by_hash(self, block_hash: Hash32) -> BlockBody: peer = cast(LESPeer, self.peer_pool.highest_td_peer) self.logger.debug("Fetching block %s from %s", encode_hex(block_hash), peer) request_id = gen_request_id() @@ -154,7 +181,7 @@ async def get_block_body_by_hash(self, block_hash: Hash32) -> BlockBody: @alru_cache(maxsize=1024, cache_exceptions=False) @service_timeout(COMPLETION_TIMEOUT) - async def get_receipts(self, block_hash: Hash32) -> List[Receipt]: + async def coro_get_receipts(self, block_hash: Hash32) -> List[Receipt]: peer = cast(LESPeer, self.peer_pool.highest_td_peer) self.logger.debug("Fetching %s receipts from %s", encode_hex(block_hash), peer) request_id = gen_request_id() @@ -169,7 +196,7 @@ async def get_receipts(self, block_hash: Hash32) -> List[Receipt]: @alru_cache(maxsize=1024, cache_exceptions=False) @service_timeout(COMPLETION_TIMEOUT) - async def get_account(self, block_hash: Hash32, address: Address) -> Account: + async def coro_get_account(self, block_hash: Hash32, address: Address) -> Account: return await self._retry_on_bad_response( partial(self._get_account_from_peer, block_hash, address) ) @@ -194,7 +221,7 @@ async def _get_account_from_peer( @alru_cache(maxsize=1024, cache_exceptions=False) @service_timeout(COMPLETION_TIMEOUT) - async def get_contract_code(self, block_hash: Hash32, address: Address) -> bytes: + async def coro_get_contract_code(self, block_hash: Hash32, address: Address) -> bytes: """ :param block_hash: find code as of the block with block_hash :param address: which contract to look up @@ -207,7 +234,7 @@ async def get_contract_code(self, block_hash: Hash32, address: Address) -> bytes # get account for later verification, and # to confirm that our highest total 
difficulty peer has the info try: - account = await self.get_account(block_hash, address) + account = await self.coro_get_account(block_hash, address) except HeaderNotFound as exc: raise NoEligiblePeers("Our best peer does not have header %s" % block_hash) from exc diff --git a/trinity/utils/async_dispatch.py b/trinity/utils/async_dispatch.py new file mode 100644 index 0000000000..aa6d5d934d --- /dev/null +++ b/trinity/utils/async_dispatch.py @@ -0,0 +1,18 @@ +import asyncio +import functools +from typing import ( + Any, + Awaitable, + Callable +) + + +def async_method(method_name: str) -> Callable[..., Any]: + async def method(self: Any, *args: Any, **kwargs: Any) -> Awaitable[Any]: + loop = asyncio.get_event_loop() + + func = getattr(self, method_name) + pfunc = functools.partial(func, *args, **kwargs) + + return await loop.run_in_executor(None, pfunc) + return method diff --git a/trinity/utils/async_errors.py b/trinity/utils/async_errors.py new file mode 100644 index 0000000000..f0c24c3ffa --- /dev/null +++ b/trinity/utils/async_errors.py @@ -0,0 +1,19 @@ +from typing import ( + Awaitable, + Optional, + Tuple, + TypeVar, +) + + +TReturn = TypeVar("TReturn") + + +async def await_and_wrap_errors( + awaitable: Awaitable[TReturn]) -> Tuple[Optional[TReturn], Optional[Exception]]: + try: + val = await awaitable + except Exception as e: + return None, e + else: + return val, None diff --git a/trinity/utils/db_proxy.py b/trinity/utils/db_proxy.py new file mode 100644 index 0000000000..cb6cea95d8 --- /dev/null +++ b/trinity/utils/db_proxy.py @@ -0,0 +1,34 @@ +from multiprocessing.managers import ( + BaseManager, +) +import pathlib + +from trinity.chains import ( + AsyncHeaderChainProxy, + ChainProxy, +) +from trinity.db.chain import ChainDBProxy +from trinity.db.base import DBProxy +from trinity.db.header import ( + AsyncHeaderDBProxy +) + + +def create_db_manager(ipc_path: pathlib.Path) -> BaseManager: + """ + We're still using 'str' here on param ipc_path because an 
issue with + multi-processing not being able to interpret 'Path' objects correctly + """ + class DBManager(BaseManager): + pass + + # Typeshed definitions for multiprocessing.managers is incomplete, so ignore them for now: + # https://github.com/python/typeshed/blob/85a788dbcaa5e9e9a62e55f15d44530cd28ba830/stdlib/3/multiprocessing/managers.pyi#L3 + DBManager.register('get_db', proxytype=DBProxy) # type: ignore + DBManager.register('get_chaindb', proxytype=ChainDBProxy) # type: ignore + DBManager.register('get_chain', proxytype=ChainProxy) # type: ignore + DBManager.register('get_headerdb', proxytype=AsyncHeaderDBProxy) # type: ignore + DBManager.register('get_header_chain', proxytype=AsyncHeaderChainProxy) # type: ignore + + manager = DBManager(address=str(ipc_path)) # type: ignore + return manager diff --git a/trinity/utils/shutdown.py b/trinity/utils/shutdown.py new file mode 100644 index 0000000000..9911844d1d --- /dev/null +++ b/trinity/utils/shutdown.py @@ -0,0 +1,27 @@ +import asyncio +import signal + +from lahja import ( + Endpoint, +) + +from p2p.service import ( + BaseService, +) + + +async def exit_on_signal(service_to_exit: BaseService, endpoint: Endpoint = None) -> None: + loop = service_to_exit.get_event_loop() + sigint_received = asyncio.Event() + for sig in [signal.SIGINT, signal.SIGTERM]: + # TODO also support Windows + loop.add_signal_handler(sig, sigint_received.set) + + await sigint_received.wait() + try: + await service_to_exit.cancel() + if endpoint is not None: + endpoint.stop() + service_to_exit._executor.shutdown(wait=True) + finally: + loop.stop()