Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up P2P client creation #343

Merged
merged 17 commits into from
Aug 3, 2021
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hivemind/p2p/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PHandlerError
from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PDaemonError, P2PHandlerError
from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo
from hivemind.p2p.servicer import ServicerBase, StubBase
88 changes: 40 additions & 48 deletions hivemind/p2p/p2p_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from contextlib import closing, suppress
from dataclasses import dataclass
from importlib.resources import path
from subprocess import Popen
from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Sequence, Tuple, TypeVar, Union

from multiaddr import Multiaddr
Expand Down Expand Up @@ -68,8 +67,8 @@ def __init__(self):
self.peer_id = None
self._child = None
self._alive = False
self._reader_task = None
self._listen_task = None
self._server_stopped = asyncio.Event()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This event is redundant, we can just cancel self._listen_task.


@classmethod
async def create(
Expand All @@ -90,9 +89,7 @@ async def create(
use_relay_discovery: bool = False,
use_auto_relay: bool = False,
relay_hop_limit: int = 0,
quiet: bool = True,
ping_n_attempts: int = 5,
ping_delay: float = 0.4,
startup_timeout: float = 15,
) -> "P2P":
"""
Start a new p2pd process and connect to it.
Expand All @@ -113,10 +110,7 @@ async def create(
:param use_relay_discovery: enables passive discovery for relay
:param use_auto_relay: enables autorelay
:param relay_hop_limit: sets the hop limit for hop relays
:param quiet: make the daemon process quiet
:param ping_n_attempts: try to ping the daemon with this number of attempts after starting it
:param ping_delay: wait for ``ping_delay * (2 ** (k - 1))`` seconds before the k-th attempt to ping the daemon
(in particular, wait for ``ping_delay`` seconds before the first attempt)
:param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
:return: a wrapper for the p2p daemon
"""

Expand Down Expand Up @@ -157,37 +151,26 @@ async def create(
autoRelay=use_auto_relay,
relayHopLimit=relay_hop_limit,
b=need_bootstrap,
q=quiet,
**process_kwargs,
)

self._child = Popen(args=proc_args, encoding="utf8")
self._child = await asyncio.subprocess.create_subprocess_exec(
*proc_args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
)
self._alive = True
self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)

await self._ping_daemon_with_retries(ping_n_attempts, ping_delay)
ready = asyncio.Future()
self._reader_task = asyncio.create_task(self._read_outputs(ready))
try:
await asyncio.wait_for(ready, startup_timeout)
except asyncio.TimeoutError:
await self.shutdown()
raise P2PDaemonError(f"Daemon failed to start in {startup_timeout:.1f} seconds")

self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
await self._ping_daemon()
return self

async def _ping_daemon_with_retries(self, ping_n_attempts: int, ping_delay: float) -> None:
for try_number in range(ping_n_attempts):
await asyncio.sleep(ping_delay * (2 ** try_number))

if self._child.poll() is not None: # Process died
break

try:
await self._ping_daemon()
break
except Exception as e:
if try_number == ping_n_attempts - 1:
logger.exception("Failed to ping p2pd that has just started")
await self.shutdown()
raise

if self._child.returncode is not None:
raise RuntimeError(f"The p2p daemon has died with return code {self._child.returncode}")

@classmethod
async def replicate(cls, daemon_listen_maddr: Multiaddr) -> "P2P":
"""
Expand Down Expand Up @@ -437,20 +420,10 @@ def iterate_protobuf_handler(
def _start_listening(self) -> None:
async def listen() -> None:
async with self._client.listen():
await self._server_stopped.wait()
await asyncio.Future() # Wait until this task will be cancelled in _terminate()

self._listen_task = asyncio.create_task(listen())

async def _stop_listening(self) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cancelling self._listen_task is enough here. The corresponding code moved to self._terminate() (since it is not async anymore).

if self._listen_task is not None:
self._server_stopped.set()
self._listen_task.cancel()
try:
await self._listen_task
except asyncio.CancelledError:
self._listen_task = None
self._server_stopped.clear()

async def add_binary_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
if self._listen_task is None:
self._start_listening()
Expand All @@ -469,14 +442,19 @@ def is_alive(self) -> bool:
return self._alive

async def shutdown(self) -> None:
await self._stop_listening()
await asyncio.get_event_loop().run_in_executor(None, self._terminate)
self._terminate()
if self._child is not None:
await self._child.wait()

def _terminate(self) -> None:
if self._listen_task is not None:
self._listen_task.cancel()
if self._reader_task is not None:
self._reader_task.cancel()

self._alive = False
if self._child is not None and self._child.poll() is None:
if self._child is not None and self._child.returncode is None:
self._child.terminate()
self._child.wait()
logger.debug(f"Terminated p2pd with id = {self.peer_id}")

with suppress(FileNotFoundError):
Expand Down Expand Up @@ -504,8 +482,22 @@ def _convert_process_arg_type(val: Any) -> Any:
def _maddrs_to_str(maddrs: List[Multiaddr]) -> str:
return ",".join(str(addr) for addr in maddrs)

async def _read_outputs(self, ready: asyncio.Future) -> None:
last_line = None
while True:
line = await self._child.stdout.readline()
if not line: # Stream closed
break
last_line = line.rstrip().decode(errors="ignore")

if last_line.startswith("Peer ID:"):
ready.set_result(None)

if not ready.done():
ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}"))


class P2PInterruptedError(Exception):
class P2PDaemonError(RuntimeError):
pass


Expand Down
11 changes: 10 additions & 1 deletion tests/test_p2p_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest
from multiaddr import Multiaddr

from hivemind.p2p import P2P, P2PHandlerError
from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
from hivemind.proto import dht_pb2
from hivemind.utils.serializer import MSGPackSerializer

Expand All @@ -33,6 +33,15 @@ async def test_daemon_killed_on_del():
assert not is_process_running(child_pid)


@pytest.mark.asyncio
async def test_startup_error_message():
with pytest.raises(P2PDaemonError, match=r"Failed to connect to bootstrap peers"):
await P2P.create(initial_peers=["/ip4/127.0.0.1/tcp/666/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"])

with pytest.raises(P2PDaemonError, match=r"Daemon failed to start in .+ seconds"):
await P2P.create(startup_timeout=0.1) # Test that startup_timeout works


@pytest.mark.parametrize(
"host_maddrs",
[
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils/dht_swarms.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def run_node(initial_peers: List[Multiaddr], info_queue: mp.Queue, **kwargs):
asyncio.set_event_loop(asyncio.new_event_loop())
loop = asyncio.get_event_loop()

node = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, ping_n_attempts=10, **kwargs))
node = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, **kwargs))
maddrs = loop.run_until_complete(node.get_visible_maddrs())

info_queue.put((node.node_id, node.peer_id, maddrs))
Expand Down