Skip to content

Commit

Permalink
Speed up DHT swarm creation
Browse files Browse the repository at this point in the history
  • Loading branch information
borzunov committed Jul 16, 2021
1 parent 36282f8 commit 2ae476f
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
4 changes: 3 additions & 1 deletion hivemind/dht/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class DHT(mp.Process):
The validators will be combined using the CompositeValidator class. It merges them when possible
(according to their `.merge_with()` policies) and orders them according to the `.priority` properties.
:param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
:param await_ready: if True, the constructor waits until the DHT process is ready to process incoming requests
:param kwargs: any other params will be forwarded to DHTNode and hivemind.p2p.P2P upon creation
"""

Expand All @@ -64,6 +65,7 @@ def __init__(
max_workers: Optional[int] = None,
record_validators: Iterable[RecordValidatorBase] = (),
shutdown_timeout: float = 3,
await_ready: bool = True,
**kwargs,
):
self._parent_pid = os.getpid()
Expand Down Expand Up @@ -91,7 +93,7 @@ def __init__(
self._p2p_replica = None

if start:
self.run_in_background(await_ready=True)
self.run_in_background(await_ready=await_ready)

def run(self) -> None:
"""Serve DHT forever. This function will not return until DHT node is shut down"""
Expand Down
5 changes: 3 additions & 2 deletions tests/test_averaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,9 @@ def test_allgather():
futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo="bar")))

gathered_data = [future.result() for future in futures]
gathered_data_reprs = [repr(sorted({peer_id.to_base58(): data for peer_id, data in result.items()}))
for result in gathered_data]
gathered_data_reprs = [
repr(sorted({peer_id.to_base58(): data for peer_id, data in result.items()})) for result in gathered_data
]
assert len(set(gathered_data_reprs)) == 2

reference_metadata = {
Expand Down
12 changes: 7 additions & 5 deletions tests/test_utils/dht_swarms.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,11 @@ async def launch_star_shaped_swarm(n_peers: int, **kwargs) -> List[DHTNode]:


def launch_dht_instances(n_peers: int, **kwargs) -> List[DHT]:
# TODO: Do it in parallel
dhts = [DHT(start=True, **kwargs)]
initial_peers = dhts[0].get_visible_maddrs()

instances = [DHT(start=True, **kwargs)]
initial_peers = instances[0].get_visible_maddrs()
instances.extend(DHT(initial_peers=initial_peers, start=True, **kwargs) for _ in range(n_peers - 1))
return instances
dhts.extend(DHT(initial_peers=initial_peers, start=True, await_ready=False, **kwargs) for _ in range(n_peers - 1))
for instance in dhts[1:]:
instance.ready.wait()

return dhts

0 comments on commit 2ae476f

Please sign in to comment.