Disentangle DecentralizedAverager components, add weights (#217)
This PR changes the internal logic of DecentralizedAverager to make matchmaking code independent of allreduce and vice versa.

- Matchmaking now returns GroupInfo (before: it returned AllReduceRunner)
- Matchmaking no longer stores AllReduce parameters
- Matchmaking no longer owns averaged_tensors
- Matchmaking no longer handles load balancing
- Removed group_key_seed (duplicate of group_id)
- throughput and client_mode are now allgathered via data_for_gather
- AllReduceRunner now accepts optional peer-wise weights (see the usage sketch after this list)
- Added a test for weighted averaging
- Fixed a minor bug: when encountering an internal error, the averager attempts to warn its groupmates. Previously, it would send the warning even to peers that cannot accept incoming requests, which produced confusing error messages.
- load_balance_peers is now run in an executor to avoid latency issues
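
For context, a hedged usage sketch of the new weight and gather arguments of DecentralizedAverager.step. The prefix, group sizes, tensor shapes and the top-level hivemind.DecentralizedAverager import are illustrative assumptions, not part of this diff:

```python
import torch
import hivemind

# a local DHT node and an averager holding a single tensor (illustrative values)
dht = hivemind.DHT(start=True)
averager = hivemind.DecentralizedAverager(
    averaged_tensors=[torch.randn(16)],  # leaf tensors, averaged in-place
    dht=dht, start=True,
    prefix='my_experiment', target_group_size=4, min_group_size=2)

# a peer that processed twice as many samples can ask for twice the weight;
# `gather` is serialized, allgathered within the group, and returned for every groupmate
gathered = averager.step(weight=2.0, gather=dict(local_step=123), timeout=60)
if gathered is not None:  # None means no group was found or averaging failed within the timeout
    for endpoint, metadata in gathered.items():
        print(endpoint, metadata)
```

At least min_group_size peers have to call step() around the same time for matchmaking to assemble a group; with wait=False the call returns an MPFuture and runs in the background instead.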

Co-authored-by: Max Ryabinin <[email protected]>
justheuristic and mryab authored Apr 10, 2021
1 parent ca6d87a commit 053c7c7
Showing 8 changed files with 191 additions and 125 deletions.
2 changes: 1 addition & 1 deletion docs/modules/client.rst
@@ -21,4 +21,4 @@
.. autoclass:: DecentralizedAverager
:members:
:member-order: bysource
:exclude-members: get_tensors, update_tensors, rpc_join_group, rpc_aggregate_part
:exclude-members: get_tensors, get_tensors_async, update_tensors, rpc_join_group, rpc_aggregate_part
97 changes: 70 additions & 27 deletions hivemind/client/averaging/__init__.py
@@ -11,23 +11,26 @@
import uuid
import weakref
from concurrent.futures.thread import ThreadPoolExecutor
from dataclasses import asdict
from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator

import grpc
from grpc._cython.cygrpc import InternalError
import torch
import numpy as np

import hivemind
from hivemind.dht import DHT, DHTID
from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, split_into_parts
from hivemind.client.averaging.load_balancing import load_balance_peers
from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
from hivemind.client.averaging.group_info import GroupInfo
from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, \
serialize_torch_tensor, deserialize_torch_tensor, split_for_streaming, combine_from_streaming
from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
from hivemind.utils import Endpoint, Port, MPFuture, get_logger
from hivemind.utils import Endpoint, Port, MPFuture, get_logger, TensorDescriptor

# flavour types
StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
@@ -85,7 +88,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
_pending_group_assembled: asyncio.Event
serializer = MSGPackSerializer

def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *, start: bool,
def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start: bool,
prefix: str, target_group_size: int, min_group_size: int = 2, initial_group_bits: Optional[str] = None,
averaging_expiration: float = 15, request_timeout: float = 3, chunk_size_bytes: int = 2 ** 16,
allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0,
@@ -112,12 +115,15 @@ def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.D
for tensor in self._averaged_tensors:
assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
tensor.share_memory_()
self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
self.schema_hash = compute_schema_hash(self._averaged_tensors)
self._throughput = throughput

self.matchmaking_kwargs = dict(
prefix=prefix, initial_group_bits=initial_group_bits, target_group_size=target_group_size,
min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout,
chunk_size_bytes=chunk_size_bytes, compression_type=compression_type,
throughput=throughput, min_vector_size=min_vector_size)
min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout)
self.allreduce_kwargs = dict(compression_type=compression_type, chunk_size_bytes=chunk_size_bytes,
min_vector_size=min_vector_size)
self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
self._running_groups: Dict[GroupID, AllReduceRunner] = {} # one or more assembled groups that run all-reduce

@@ -170,8 +176,8 @@ async def _run():
else:
logger.info(f"The averager running in an experimental client mode, please report any bugs.")

self._matchmaking = Matchmaking(self.endpoint, self._averaged_tensors, self.dht, **self.matchmaking_kwargs,
client_mode=not self.listen, return_deltas=True)
self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs,
client_mode=not self.listen)
if self.listen:
asyncio.create_task(self._declare_for_download_periodically())

@@ -207,55 +213,58 @@ def __del__(self):
if self._parent_pid != os.getpid() or self.is_alive():
self.shutdown()

def step(self, gather: Optional[DataForGather] = None, allow_retries: bool = True, timeout: Optional[float] = None,
wait=True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
def step(self, gather: Optional[DataForGather] = None, weight: float = 1.0, timeout: Optional[float] = None,
allow_retries: bool = True, wait: bool = True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
"""
Set up the averager to look for a group and run one round of averaging; return the gathered metadata on success, None on failure
:param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
within the specified timeout
:param gather: optionally send this information to all peers in the next group and gather it from every groupmate
(this operation is known as all-gather). The gathered data will be available as the output of this function.
:param weight: averaging weight for this peer, int or float, must be strictly positive
:param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
within the specified timeout
:param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
:param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
:returns: on success, update averaged_tensors and return the metadata gathered from each groupmate; on failure, return None
"""
assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"
future, _future = MPFuture.make_pair()
gather_binary = self.serializer.dumps(gather) # serialize here to avoid loading modules in the averager process
self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary,
self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight,
allow_retries=allow_retries, timeout=timeout)))
return future.result() if wait else future

async def _step(self, *, future: MPFuture, gather_binary: bytes, allow_retries: bool, timeout: Optional[float]):
async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
allow_retries: bool, timeout: Optional[float]):
loop = asyncio.get_event_loop()
start_time = get_dht_time()
group_id = None

while not future.done():
try:
self._pending_group_assembled.clear()
allreduce_group = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=gather_binary)
if allreduce_group is None:
data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather)
if group_info is None:
raise AllreduceException("Averaging step failed: could not find a group.")

group_id = allreduce_group.group_id
self._running_groups[group_id] = allreduce_group
group_id = group_info.group_id
allreduce_runner = await self._make_allreduce_runner(group_info, **self.allreduce_kwargs)
self._running_groups[group_id] = allreduce_runner
self._pending_group_assembled.set()
await asyncio.wait_for(allreduce_group.run(), self._allreduce_timeout)
await loop.run_in_executor(None, self.update_tensors, allreduce_group)
await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
await loop.run_in_executor(None, self.update_tensors, allreduce_runner)

# averaging is finished, exit the loop
gathered_items = map(self.serializer.loads, allreduce_group.gathered)
gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
future.set_result(gathered_data_by_peer)
future.set_result(allreduce_runner.gathered)

except (AllreduceException, MatchmakingException, asyncio.InvalidStateError,
grpc.RpcError, grpc.aio.AioRpcError, InternalError) as e:
except (AllreduceException, MatchmakingException, AssertionError,
asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError, InternalError) as e:
time_elapsed = get_dht_time() - start_time
if not allow_retries or (timeout is not None and timeout < time_elapsed):
logger.warning(f"Averager caught {e}")
future.set_result(None)
else:
logger.debug(f"caught {e}, retrying")
logger.warning(f"Averager caught {e}, retrying")

except Exception as e:
future.set_exception(e)
@@ -264,6 +273,23 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, allow_retries:
_ = self._running_groups.pop(group_id, None)
self._pending_group_assembled.set()

async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
""" Use a group description found by Matchmaking to form AllreduceRunner """
try:
weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))

# compute optimal part sizes from peer throughputs
incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]
part_sizes = await asyncio.get_event_loop().run_in_executor(
None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size)
async with self.get_tensors_async() as averaged_tensors:
return AllReduceRunner(group_id=group_info.group_id, tensors=averaged_tensors, endpoint=self.endpoint,
ordered_group_endpoints=group_info.endpoints, part_sizes=part_sizes,
weights=weights, gathered=user_gathered, return_deltas=True, **kwargs)
except Exception as e:
raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {group_info}")

def update_tensors(self, allreduce_group: AllReduceRunner):
"""
a private (extendable) method that applies changes from a finished allreduce to local tensors
Expand All @@ -288,6 +314,15 @@ def get_tensors(self) -> Sequence[torch.Tensor]:
yield self._averaged_tensors
self.last_updated = get_dht_time()

@contextlib.asynccontextmanager
async def get_tensors_async(self) -> Sequence[torch.Tensor]:
""" Like get_tensors, but uses an asynchronous contextmanager """
try:
await asyncio.get_event_loop().run_in_executor(None, self.lock_averaged_tensors.acquire)
yield self._averaged_tensors
finally:
self.lock_averaged_tensors.release()

async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
""" accept or reject a join request from another averager; if accepted, run him through allreduce steps """
@@ -478,3 +513,11 @@ def _background_thread_fetch_current_state(serializer: SerializerBase, pipe: mp.
future.set_exception(e)
logger.warning(e)
continue


def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
""" A hash that describes follower's tensor shapes, dtypes, devices, but not the actual values """
schema_dicts = [{field_name: str(field_value)
for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
for tensor in tensors]
return DHTID.generate(source=schema_dicts).to_bytes()
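
For readers skimming the diff, a minimal sketch of how the allgathered metadata round-trips: _step packs one blob per peer, and _make_allreduce_runner unpacks all of them to obtain per-peer weights and load-balanced part sizes. The concrete values below are made up; only the packing format mirrors the code above.

```python
from hivemind.utils.serializer import MSGPackSerializer

# what a single peer contributes to the allgather (see _step above)
weight, throughput, listen = 2.0, 100.0, True
gather_binary = MSGPackSerializer.dumps(dict(local_step=123))  # user payload from step(gather=...)
data_for_gather = MSGPackSerializer.dumps([weight, throughput, listen, gather_binary])

# after matchmaking, group_info.gathered holds one such blob per endpoint;
# _make_allreduce_runner unzips them into per-peer weights, throughputs and modes
weights, throughputs, modes, user_gathered = zip(*map(MSGPackSerializer.loads, [data_for_gather]))

# client-mode peers (listen == False) get zero incoming throughput, so
# load_balance_peers will not assign them any part of the averaged vector
incoming_throughputs = [thr if listens else 0.0 for thr, listens in zip(throughputs, modes)]
```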