Set default DHT num_workers = 4 #342

Merged: 10 commits, Jul 31, 2021
hivemind/dht/__init__.py (10 changes: 5 additions & 5 deletions)

@@ -23,7 +23,7 @@

from multiaddr import Multiaddr

-from hivemind.dht.node import DHTNode
+from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
from hivemind.p2p import P2P, PeerID
@@ -43,7 +43,7 @@ class DHT(mp.Process):
:param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
:param start: if True, automatically starts the background process on creation. Otherwise await manual start
:param daemon: if True, the background process is marked as daemon and automatically terminated after main process
-:param max_workers: declare_experts and get_experts will use up to this many parallel workers
+:param num_workers: declare_experts and get_experts will use up to this many parallel workers
(but no more than one per key)
:param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
:param record_validators: instances of RecordValidatorBase used for signing and validating stored records.
@@ -62,7 +62,7 @@ def __init__(
*,
start: bool,
daemon: bool = True,
-max_workers: Optional[int] = None,
+num_workers: int = DEFAULT_NUM_WORKERS,
record_validators: Iterable[RecordValidatorBase] = (),
shutdown_timeout: float = 3,
await_ready: bool = True,
@@ -81,7 +81,7 @@ def __init__(
raise TypeError("initial_peers should be of type Optional[Sequence[Union[Multiaddr, str]]]")
self.initial_peers = initial_peers
self.kwargs = kwargs
-self.max_workers = max_workers
+self.num_workers = num_workers

self._record_validator = CompositeValidator(record_validators)
self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
@@ -106,7 +106,7 @@ def run(self) -> None:
async def _run():
self._node = await DHTNode.create(
initial_peers=self.initial_peers,
-num_workers=self.max_workers or 1,
+num_workers=self.num_workers,
record_validator=self._record_validator,
**self.kwargs,
)
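With this change the constructor keyword is num_workers (default 4) rather than max_workers (default None), and the value is forwarded to DHTNode.create unchanged instead of collapsing to a single worker. A minimal usage sketch of the new signature; the call below is illustrative and not part of this diff:

from hivemind.dht import DHT

# Omitting num_workers now yields DEFAULT_NUM_WORKERS (4) parallel workers;
# an explicit value is forwarded to DHTNode.create unchanged.
dht = DHT(start=True, num_workers=4)
try:
    print(dht.num_workers)  # 4
finally:
    dht.shutdown()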
hivemind/dht/node.py (8 changes: 6 additions & 2 deletions)

@@ -2,6 +2,7 @@

import asyncio
import dataclasses
+import os
import random
from collections import Counter, defaultdict
from dataclasses import dataclass, field
@@ -38,6 +39,9 @@
logger = get_logger(__name__)


+DEFAULT_NUM_WORKERS = int(os.getenv("HIVEMIND_DHT_NUM_WORKERS", 4))


class DHTNode:
"""
Asyncio-based class that represents one DHT participant. Created via await DHTNode.create(...)
@@ -110,7 +114,7 @@ async def create(
cache_refresh_before_expiry: float = 5,
cache_on_store: bool = True,
reuse_get_requests: bool = True,
-num_workers: int = 1,
+num_workers: int = DEFAULT_NUM_WORKERS,
chunk_size: int = 16,
blacklist_time: float = 5.0,
backoff_rate: float = 2.0,
@@ -154,7 +158,7 @@ async def create(
:param backoff_rate: blacklist time will be multiplied by :backoff_rate: for each successive non-response
:param validate: if True, use initial peers to validate that this node is accessible and synchronized
:param strict: if True, any error encountered in validation will interrupt the creation of DHTNode
:param client_mode: if False (default), this node will accept incoming requests as a full DHT "citzen"
:param client_mode: if False (default), this node will accept incoming requests as a full DHT "citizen"
if True, this node will refuse any incoming requests, effectively being only a client
:param record_validator: instance of RecordValidatorBase used for signing and validating stored records
:param authorizer: instance of AuthorizerBase used for signing and validating requests and response
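The new default is read from the HIVEMIND_DHT_NUM_WORKERS environment variable once, when hivemind.dht.node is first imported, so it can be overridden per run without code changes. A small sketch of that override; setting the variable from Python here is only for illustration:

import os

# Must be set before hivemind.dht.node is imported, because
# DEFAULT_NUM_WORKERS is evaluated once at module import time.
os.environ["HIVEMIND_DHT_NUM_WORKERS"] = "8"

from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode

print(DEFAULT_NUM_WORKERS)  # 8: DHTNode.create(...) now defaults to 8 workers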
hivemind/moe/client/beam_search.py (6 changes: 3 additions & 3 deletions)

@@ -125,7 +125,7 @@ async def _get_initial_beam(
cache_expiration: DHTExpiration,
num_workers: Optional[int] = None,
) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
-num_workers = num_workers or dht.max_workers or beam_size
+num_workers = num_workers or dht.num_workers or beam_size
beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
unattempted_indices: List[Coordinate] = sorted(
range(len(scores)), key=scores.__getitem__
@@ -206,7 +206,7 @@ async def _get_active_successors(
num_workers: Optional[int] = None,
) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
grid_size = grid_size or float("inf")
-num_workers = num_workers or min(len(prefixes), dht.max_workers or len(prefixes))
+num_workers = num_workers or min(len(prefixes), dht.num_workers or len(prefixes))
dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
for prefix, found in dht_responses.items():
@@ -270,7 +270,7 @@ async def _find_best_experts(
cache_expiration: DHTExpiration,
num_workers: Optional[int] = None,
) -> List[RemoteExpert]:
-num_workers = num_workers or min(beam_size, dht.max_workers or beam_size)
+num_workers = num_workers or min(beam_size, dht.num_workers or beam_size)

# form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = await cls._get_initial_beam(
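Each beam-search helper now takes its worker budget from dht.num_workers using the same fallback pattern: an explicit argument wins, otherwise the DHT-wide value capped by the number of keys it could query in parallel. A toy resolution of the expression from _find_best_experts, with illustrative numbers:

beam_size = 8
dht_num_workers = 4          # new default
explicit_num_workers = None  # caller did not override

num_workers = explicit_num_workers or min(beam_size, dht_num_workers or beam_size)
assert num_workers == 4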
hivemind/moe/server/dht_handler.py (4 changes: 2 additions & 2 deletions)

@@ -56,7 +56,7 @@ def declare_experts(
async def _declare_experts(
dht: DHT, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration
) -> Dict[ExpertUID, bool]:
-num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
+num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
expiration_time = get_dht_time() + expiration
data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
for uid in uids:
@@ -89,7 +89,7 @@ async def _get_experts(
) -> List[Optional[RemoteExpert]]:
if expiration_time is None:
expiration_time = get_dht_time()
-num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
+num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)

experts: List[Optional[RemoteExpert]] = [None] * len(uids)
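The declare/get helpers apply the same idea but cap the budget by the number of expert UIDs, so a small batch never requests more workers than it has keys. A small illustration of the expression from _declare_experts and _get_experts, with made-up values; note that with the new integer default the is None branch only fires if a caller passes num_workers=None explicitly:

uids = ["expert.0", "expert.1"]  # two keys to declare or fetch (illustrative)
dht_num_workers = 4              # new default from this PR

num_workers = len(uids) if dht_num_workers is None else min(len(uids), dht_num_workers)
assert num_workers == 2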