Use PeerID exclusively to address MoE experts (#479)
Changed declare_experts / RemoteExpert to use only the p2p peer ID instead of the whole multiaddress.
This slightly reduces code complexity and makes it easier to share experts from peers with dynamic IP addresses.

It also fixes a DHT edge case I discovered while working on this change.

Minor changes:
- fixed an edge case: previously, the DHT would **freeze** when accessing DHT.peer_id or otherwise calling .run_coroutine from inside another run_coroutine
- merged RemoteExpertInfo and UidEndpoint into one structure (ExpertInfo), now defined in expert_uid.py
- moved expert_uid.py from hivemind.moe.server to hivemind.moe to avoid circular imports
- renamed get_expert_stub to get_server_stub since it is not expert-specific
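
To illustrate the new addressing scheme, here is a minimal client-side sketch adapted from the updated throughput benchmark below; the peer ID and multiaddresses are placeholders that a real client would obtain from a running Server:

```python
from hivemind.moe.client.expert import RemoteExpert
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.expert_uid import ExpertInfo
from hivemind.p2p import P2P, PeerID

# placeholders: in practice both values come from the server process
server_peer_id: PeerID = ...
server_maddrs: list = []

# spawn a client-side p2p daemon connected to the server's multiaddresses
p2p = RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=server_maddrs))

# an expert is now addressed by its uid plus the server's PeerID alone,
# rather than a full PeerInfo with multiaddresses
expert = RemoteExpert(expert_info=ExpertInfo(uid="expert.0", peer_id=server_peer_id), p2p=p2p)
```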


Co-authored-by: Aleksandr Borzunov <[email protected]>
Co-authored-by: Pavel Samygin <[email protected]>
Co-authored-by: Max Ryabinin <[email protected]>
4 people authored Jun 7, 2022
1 parent c49802a commit 25366a1
Showing 14 changed files with 130 additions and 143 deletions.
8 changes: 4 additions & 4 deletions benchmarks/benchmark_throughput.py
@@ -7,11 +7,12 @@
import torch

from hivemind.dht import DHT
from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo
from hivemind.moe.client.expert import RemoteExpert
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.expert_uid import ExpertInfo
from hivemind.moe.server import ExpertBackend, Server
from hivemind.moe.server.layers import name_to_block
from hivemind.p2p import P2P, PeerInfo
from hivemind.p2p import P2P
from hivemind.utils.limits import increase_file_limit
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from hivemind.utils.tensor_descr import BatchTensorDescriptor
@@ -48,9 +49,8 @@ def client_process(
can_start.wait()

p2p = RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=server_maddrs))
peer_info = PeerInfo(server_peer_id, server_maddrs)
experts = [
RemoteExpert(expert_info=RemoteExpertInfo(uid=f"expert.{i}", peer_info=peer_info), p2p=p2p)
RemoteExpert(expert_info=ExpertInfo(uid=f"expert.{i}", peer_id=server_peer_id), p2p=p2p)
for i in range(num_experts)
]

9 changes: 8 additions & 1 deletion hivemind/dht/dht.py
@@ -169,6 +169,7 @@ def get(
:param kwargs: parameters forwarded to DHTNode.get_many_by_id
:returns: (value, expiration time); if value was not found, returns None
"""
assert os.getpid() != self.pid, "calling *external* DHT interface from inside DHT will result in a deadlock"
future = MPFuture()
self._outer_pipe.send(("_get", [], dict(key=key, latest=latest, future=future, **kwargs)))
return future if return_future else future.result()
@@ -202,6 +203,7 @@ def store(
:param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
:returns: True if store succeeds, False if it fails (due to no response or newer value)
"""
assert os.getpid() != self.pid, "calling *external* DHT interface from inside DHT will result in a deadlock"
future = MPFuture()
self._outer_pipe.send(
(
@@ -246,6 +248,7 @@ def run_coroutine(
or use asyncio.get_event_loop().run_in_executor(...) to prevent coroutine from blocking background DHT tasks
:note: when run_coroutine is called with return_future=False, MPFuture can be cancelled to interrupt the task.
"""
assert os.getpid() != self.pid, "calling *external* DHT interface from inside DHT will result in a deadlock"
future = MPFuture()
self._outer_pipe.send(("_run_coroutine", [], dict(coro=coro, future=future)))
return future if return_future else future.result()
@@ -275,7 +278,11 @@ async def _add_validators(_dht: DHT, node: DHTNode, record_validators: Iterable[
@property
def peer_id(self) -> PeerID:
if self._peer_id is None:
self._peer_id = self.run_coroutine(DHT._get_peer_id)
if os.getpid() == self.pid:
self._peer_id = self._node.peer_id
else:
# note: we cannot run_coroutine from the same pid because it would deadlock the event loop
self._peer_id = self.run_coroutine(DHT._get_peer_id)
return self._peer_id

@staticmethod
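
To make the fixed **freeze** concrete, here is a hypothetical reproduction of the pattern that used to deadlock: a coroutine submitted through run_coroutine that touches DHT.peer_id from inside the DHT process (the coroutine itself is invented for this example):

```python
from hivemind.dht import DHT, DHTNode


async def report_own_peer_id(dht: DHT, node: DHTNode) -> str:
    # runs inside the DHT process; before this commit, reading dht.peer_id here
    # re-entered run_coroutine from the same pid and froze the event loop,
    # whereas the updated property now reads self._node.peer_id directly
    return str(dht.peer_id)


dht = DHT(start=True)
print(dht.run_coroutine(report_own_peer_id))
dht.shutdown()
```
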
111 changes: 51 additions & 60 deletions hivemind/moe/client/beam_search.py
@@ -5,25 +5,21 @@
from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union

from hivemind.dht import DHT, DHTExpiration, DHTNode
from hivemind.moe.client.expert import (
RemoteExpert,
RemoteExpertInfo,
batch_create_remote_experts,
create_remote_experts,
)
from hivemind.moe.server.expert_uid import (
from hivemind.moe.client.expert import RemoteExpert, batch_create_remote_experts, create_remote_experts
from hivemind.moe.expert_uid import (
FLAT_EXPERT,
PREFIX_PATTERN,
UID_DELIMITER,
Coordinate,
ExpertInfo,
ExpertPrefix,
ExpertUID,
Score,
UidEndpoint,
is_valid_prefix,
is_valid_uid,
)
from hivemind.p2p import PeerInfo
from hivemind.utils import MPFuture, get_dht_time, get_logger
from hivemind.p2p import PeerID
from hivemind.utils import MPFuture, ValueWithExpiration, get_dht_time, get_logger

logger = get_logger(__name__)

@@ -100,7 +96,7 @@ def __init__(

def get_initial_beam(
self, scores: Sequence[float], beam_size: int, return_future: bool = False
) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]]:
"""
:param scores: prefer suffix coordinates that have highest scores
:param beam_size: select this many active suffixes with highest scores
@@ -130,9 +126,9 @@ async def _get_initial_beam(
negative_caching: bool,
cache_expiration: DHTExpiration,
num_workers: Optional[int] = None,
) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]]:
num_workers = num_workers or dht.num_workers or beam_size
beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]] = []
unattempted_indices: List[Coordinate] = sorted(
range(len(scores)), key=scores.__getitem__
) # from worst to best
@@ -150,13 +146,7 @@ async def _get_initial_beam(
try:
maybe_prefix_data = await pending_task
if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
successors = {
coord: UidEndpoint(uid=match.value[0], peer_info=PeerInfo.from_tuple(match.value[1]))
for coord, match in maybe_prefix_data.value.items()
if isinstance(coord, Coordinate)
and isinstance(getattr(match, "value", None), list)
and len(match.value) == 2
}
successors = MoEBeamSearcher._select_valid_entries(maybe_prefix_data)
if successors:
beam.append((scores[pending_best_index], pending_best_prefix, successors))
elif maybe_prefix_data is None and negative_caching:
@@ -178,7 +168,7 @@ async def _get_initial_beam(

def get_active_successors(
self, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None, return_future: bool = False
) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
) -> Dict[ExpertPrefix, Dict[Coordinate, ExpertInfo]]:
"""
:param prefixes: a list of prefix for which to find active successor uids
:param grid_size: if specified, only return successors if they are in range [0, grid_size)
@@ -201,6 +191,22 @@ def get_active_successors(
return_future=return_future,
)

@staticmethod
def _select_valid_entries(entry: ValueWithExpiration, grid_size: Optional[int] = None):
if not isinstance(entry, ValueWithExpiration) or not isinstance(entry.value, dict):
return {}
return {
coord: ExpertInfo(uid=match.value[0], peer_id=PeerID.from_base58(match.value[1]))
for coord, match in entry.value.items()
if isinstance(coord, Coordinate)
and (grid_size is None or 0 <= coord < grid_size)
and isinstance(match, ValueWithExpiration)
and isinstance(match.value, tuple)
and len(match.value) == 2
and is_valid_uid(match.value[0])
and isinstance(match.value[1], str)
}

@staticmethod
async def _get_active_successors(
dht: DHT,
@@ -210,28 +216,18 @@ async def _get_active_successors(
negative_caching: bool,
cache_expiration: DHTExpiration,
num_workers: Optional[int] = None,
) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
) -> Dict[ExpertPrefix, Dict[Coordinate, ExpertInfo]]:
grid_size = grid_size or float("inf")
num_workers = num_workers or min(len(prefixes), dht.num_workers or len(prefixes))
dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
successors: Dict[ExpertPrefix, Dict[Coordinate, ExpertInfo]] = {}
for prefix, found in dht_responses.items():
if found and isinstance(found.value, dict):
successors[prefix] = {
coord: UidEndpoint(uid=match.value[0], peer_info=PeerInfo.from_tuple(match.value[1]))
for coord, match in found.value.items()
if isinstance(coord, Coordinate)
and 0 <= coord < grid_size
and isinstance(getattr(match, "value", None), list)
and len(match.value) == 2
}
else:
successors[prefix] = {}
if found is None and negative_caching:
logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
asyncio.create_task(
node.store(prefix, subkey=-1, value=None, expiration_time=get_dht_time() + cache_expiration)
)
successors[prefix] = MoEBeamSearcher._select_valid_entries(found, grid_size)
if not successors[prefix] and negative_caching:
logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
asyncio.create_task(
node.store(prefix, subkey=-1, value=None, expiration_time=get_dht_time() + cache_expiration)
)
return successors

def find_best_experts(
@@ -246,7 +242,6 @@ def find_best_experts(
After time_budget is reached, beam search won't search for more experts and instead fall back on local cache
Please note that any queries that fall outside the budget will still be performed in background and cached
for subsequent iterations as long as DHTNode.cache_locally is True
:param num_workers: use up to this many concurrent workers to search DHT
:param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
:returns: a list that contains *up to* k_best RemoteExpert instances
"""
@@ -263,7 +258,6 @@ def find_best_experts(
),
return_future,
)

return create_remote_experts(result, self.dht, return_future)

@classmethod
@@ -277,23 +271,23 @@ async def _find_best_experts(
negative_caching: bool,
cache_expiration: DHTExpiration,
num_workers: Optional[int] = None,
) -> List[RemoteExpertInfo]:
) -> List[ExpertInfo]:
num_workers = num_workers or min(beam_size, dht.num_workers or beam_size)

# form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = await cls._get_initial_beam(
beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]] = await cls._get_initial_beam(
dht, node, prefix, beam_size, grid_scores[0], negative_caching, min(beam_size, num_workers)
)

best_experts_heap: List[Tuple[Score, UidEndpoint]] = [] # max-heap of expert uids/endpoints ordered by scores
best_experts_heap: List[Tuple[Score, ExpertInfo]] = [] # max-heap of expert infos ordered by scores
unique_experts: Set[ExpertUID] = set()

for dim_index in range(1, len(grid_scores) - 1):
for score, uid_endpoint in cls._iterate_matching_experts(beam, grid_scores):
if uid_endpoint.uid not in unique_experts:
for score, expert_info in cls._iterate_matching_experts(beam, grid_scores):
if expert_info.uid not in unique_experts:
push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
unique_experts.add(uid_endpoint.uid)
push_and_maybe_pop(best_experts_heap, (score, expert_info))
unique_experts.add(expert_info.uid)

# form new beam using successors from the current beam
dim_scores = grid_scores[dim_index]
@@ -306,6 +300,7 @@ async def _find_best_experts(
if isinstance(next_coord, int) and 0 <= next_coord < len(dim_scores)
),
)

_, best_uid_prefixes = zip(*best_active_pairs)

# search DHT for next step suffixes
@@ -324,22 +319,18 @@ async def _find_best_experts(
break

# add best experts from the final beam
for score, uid_endpoint in cls._iterate_matching_experts(beam, grid_scores):
if uid_endpoint.uid not in unique_experts:
for score, expert_info in cls._iterate_matching_experts(beam, grid_scores):
if expert_info.uid not in unique_experts:
push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
unique_experts.add(uid_endpoint.uid)
push_and_maybe_pop(best_experts_heap, (score, expert_info))
unique_experts.add(expert_info.uid)

best_experts = [
RemoteExpertInfo(uid_endpoint.uid, uid_endpoint.peer_info)
for _, uid_endpoint in sorted(best_experts_heap, reverse=True)
]
return best_experts
return [expert_info for _, expert_info in sorted(best_experts_heap, reverse=True)]

@staticmethod
def _iterate_matching_experts(
beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]], grid_scores: Sequence[Sequence[float]]
) -> Iterator[Tuple[Score, UidEndpoint]]:
beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]], grid_scores: Sequence[Sequence[float]]
) -> Iterator[Tuple[Score, ExpertInfo]]:
"""iterate over all exemplar experts attached to current beam"""
for score, prefix, suffixes in beam:
for next_coord, match in suffixes.items():
@@ -399,7 +390,7 @@ async def _batch_find_best_experts(
beam_size: int,
negative_caching: bool,
num_workers: Optional[int],
) -> Sequence[Sequence[RemoteExpertInfo]]:
) -> Sequence[Sequence[ExpertInfo]]:
batch_grid_scores = [
[tuple(grid_score[i]) for grid_score in batch_grid_scores] for i in range(len(batch_grid_scores[0]))
]
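
As a hedged companion to the new _select_valid_entries helper, the sketch below mirrors its per-entry checks: a valid DHT match is a ValueWithExpiration wrapping a (uid, base58 peer id) tuple, which gets converted into an ExpertInfo. The helper name to_expert_info is made up for this example:

```python
from typing import Optional

from hivemind.moe.expert_uid import ExpertInfo, is_valid_uid
from hivemind.p2p import PeerID
from hivemind.utils import ValueWithExpiration


def to_expert_info(match) -> Optional[ExpertInfo]:
    # mirror the per-entry validation performed by _select_valid_entries
    if (
        isinstance(match, ValueWithExpiration)
        and isinstance(match.value, tuple)
        and len(match.value) == 2
        and is_valid_uid(match.value[0])
        and isinstance(match.value[1], str)
    ):
        return ExpertInfo(uid=match.value[0], peer_id=PeerID.from_base58(match.value[1]))
    return None
```
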
33 changes: 13 additions & 20 deletions hivemind/moe/client/expert.py
@@ -1,7 +1,6 @@
from __future__ import annotations

from concurrent.futures import Future
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import torch
@@ -12,7 +11,8 @@
from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
from hivemind.dht import DHT
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import P2P, PeerInfo, StubBase
from hivemind.moe.expert_uid import ExpertInfo
from hivemind.p2p import P2P, PeerID, StubBase
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
@@ -24,16 +24,9 @@
DUMMY = torch.empty(0, requires_grad=True) # dummy tensor that triggers autograd in RemoteExpert


def get_expert_stub(p2p: P2P, server_peer_info: PeerInfo) -> "ConnectionHandlerStub":
return moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)


@dataclass(frozen=True)
class RemoteExpertInfo:
"""A simple data class containing uid of expert and server PeerInfo"""

uid: str
peer_info: PeerInfo
def get_server_stub(p2p: P2P, server_peer_id: PeerID) -> "ConnectionHandlerStub":
"""Create an RPC stub that can send requests to any expert on the specified remote server"""
return moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_id)


class RemoteExpert(nn.Module):
@@ -47,7 +40,7 @@ class RemoteExpert(nn.Module):
:param p2p: P2P instance connected to the running p2pd
"""

def __init__(self, expert_info: RemoteExpertInfo, p2p: P2P):
def __init__(self, expert_info: ExpertInfo, p2p: P2P):
super().__init__()
self._info, self.p2p = expert_info, p2p
self._rpc_info = None
@@ -57,12 +50,12 @@ def uid(self):
return self._info.uid

@property
def server_peer_info(self):
return self._info.peer_info
def peer_id(self) -> PeerID:
return self._info.peer_id

@property
def stub(self) -> StubBase:
return get_expert_stub(self.p2p, self.server_peer_info)
return get_server_stub(self.p2p, self.peer_id)

def forward(self, *args, **kwargs):
"""Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd."""
@@ -89,10 +82,10 @@ def info(self):
return self._rpc_info

def extra_repr(self):
return f"uid={self.uid}, server_peer_info={self.server_peer_info}"
return f"uid={self.uid}, server_peer_id={self.peer_id}"


def _create_remote_experts(infos: Sequence[Optional[RemoteExpertInfo]], p2p: P2P) -> List[Optional[RemoteExpert]]:
def _create_remote_experts(infos: Sequence[Optional[ExpertInfo]], p2p: P2P) -> List[Optional[RemoteExpert]]:
experts: List[Optional[RemoteExpert]] = []
for info in infos:
if info is not None:
@@ -103,7 +96,7 @@ def _create_remote_experts(infos: Sequence[Optional[RemoteExpertInfo]], p2p: P2P


def create_remote_experts(
infos: Union[Sequence[Optional[RemoteExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
infos: Union[Sequence[Optional[ExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
) -> Union[List[Optional[RemoteExpert]], Future]:
if return_future:

@@ -118,7 +111,7 @@ async def _unpack(infos_future: MPFuture, dht: DHT):


def batch_create_remote_experts(
infos: Union[Sequence[Sequence[Optional[RemoteExpertInfo]]], MPFuture],
infos: Union[Sequence[Sequence[Optional[ExpertInfo]]], MPFuture],
dht: DHT,
return_future: bool = False,
) -> Union[List[List[Optional[RemoteExpert]]], Future]:
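
Finally, a rough usage sketch of a RemoteExpert constructed as in the earlier example; the input shape is an arbitrary placeholder and must match the remote expert's actual schema:

```python
import torch

# `expert` is the RemoteExpert from the earlier sketch; 1024 is a made-up hidden size
dummy_input = torch.randn(1, 1024)
output = expert(dummy_input)  # the request is routed to the server by its PeerID
print(expert.extra_repr())    # e.g. "uid=expert.0, server_peer_id=<PeerID>"
```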