From 053c7c7d131871a3d7b5bdc28e974d5a47e9686b Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sat, 10 Apr 2021 18:38:47 +0300 Subject: [PATCH] Disentangle DecentralizedAverager components, add weights (#217) This PR changes the internal logic of DecentralizedAverager to make matchmaking code independent of allreduce and vice versa. - Matchmaking now returns GroupInfo (before: it returned AllreduceRunner) - Matchmaking no longer stores AllReduce parameters - Matchmaking no longer owns averaged_tensors - Matchmaking no longer handles load balancing - Removed group_key_seed (duplicate of group_id) - throughput and client_mode are now allgathered via data_for_gather - AllReduceRunner now accepts optional peer-wise weights - Added test for weighted averaging - Fixed a minor bug: when encountering an internal error, the averager attempts to warn its groupmates. Previously, it would send warnings even to peers that cannot accept incoming requests. This caused fabulous error messages. - load_balance_peers is now run in an executor to avoid latency issues Co-authored-by: Max Ryabinin --- docs/modules/client.rst | 2 +- hivemind/client/averaging/__init__.py | 97 ++++++++++++++------ hivemind/client/averaging/allreduce.py | 32 ++++--- hivemind/client/averaging/group_info.py | 19 ++++ hivemind/client/averaging/key_manager.py | 12 +-- hivemind/client/averaging/matchmaking.py | 108 ++++++++--------------- hivemind/proto/averaging.proto | 8 +- tests/test_averaging.py | 38 +++++++- 8 files changed, 191 insertions(+), 125 deletions(-) create mode 100644 hivemind/client/averaging/group_info.py diff --git a/docs/modules/client.rst b/docs/modules/client.rst index 65cf84480..e7ae42b90 100644 --- a/docs/modules/client.rst +++ b/docs/modules/client.rst @@ -21,4 +21,4 @@ ..
autoclass:: DecentralizedAverager :members: :member-order: bysource - :exclude-members: get_tensors, update_tensors, rpc_join_group, rpc_aggregate_part + :exclude-members: get_tensors, get_tensors_async, update_tensors, rpc_join_group, rpc_aggregate_part diff --git a/hivemind/client/averaging/__init__.py b/hivemind/client/averaging/__init__.py index 1a69344a8..ab2f09974 100644 --- a/hivemind/client/averaging/__init__.py +++ b/hivemind/client/averaging/__init__.py @@ -11,6 +11,7 @@ import uuid import weakref from concurrent.futures.thread import ThreadPoolExecutor +from dataclasses import asdict from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator import grpc @@ -18,16 +19,18 @@ import torch import numpy as np -import hivemind +from hivemind.dht import DHT, DHTID from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, split_into_parts +from hivemind.client.averaging.load_balancing import load_balance_peers from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException +from hivemind.client.averaging.group_info import GroupInfo from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2 from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, \ serialize_torch_tensor, deserialize_torch_tensor, split_for_streaming, combine_from_streaming from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration from hivemind.utils.serializer import MSGPackSerializer, SerializerBase -from hivemind.utils import Endpoint, Port, MPFuture, get_logger +from hivemind.utils import Endpoint, Port, MPFuture, get_logger, TensorDescriptor # flavour types StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader] @@ -85,7 +88,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin _pending_group_assembled: asyncio.Event serializer = MSGPackSerializer - def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *, start: bool, + def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start: bool, prefix: str, target_group_size: int, min_group_size: int = 2, initial_group_bits: Optional[str] = None, averaging_expiration: float = 15, request_timeout: float = 3, chunk_size_bytes: int = 2 ** 16, allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0, @@ -112,12 +115,15 @@ def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.D for tensor in self._averaged_tensors: assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors" tensor.share_memory_() + self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors)) + self.schema_hash = compute_schema_hash(self._averaged_tensors) + self._throughput = throughput self.matchmaking_kwargs = dict( prefix=prefix, initial_group_bits=initial_group_bits, target_group_size=target_group_size, - min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout, - chunk_size_bytes=chunk_size_bytes, compression_type=compression_type, - throughput=throughput, min_vector_size=min_vector_size) + min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout) + self.allreduce_kwargs = dict(compression_type=compression_type, chunk_size_bytes=chunk_size_bytes, + min_vector_size=min_vector_size) 
self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout self._running_groups: Dict[GroupID, AllReduceRunner] = {} # one or more assembled groups that run all-reduce @@ -170,8 +176,8 @@ async def _run(): else: logger.info(f"The averager is running in an experimental client mode, please report any bugs.") - self._matchmaking = Matchmaking(self.endpoint, self._averaged_tensors, self.dht, **self.matchmaking_kwargs, - client_mode=not self.listen, return_deltas=True) + self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs, + client_mode=not self.listen) if self.listen: asyncio.create_task(self._declare_for_download_periodically()) @@ -207,26 +213,29 @@ def __del__(self): if self._parent_pid != os.getpid() or self.is_alive(): self.shutdown() - def step(self, gather: Optional[DataForGather] = None, allow_retries: bool = True, timeout: Optional[float] = None, - wait=True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]: + def step(self, gather: Optional[DataForGather] = None, weight: float = 1.0, timeout: Optional[float] = None, + allow_retries: bool = True, wait: bool = True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]: """ Set up the averager to look for a group and run one round of averaging, return True on success, False on failure - :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again - within the specified timeout :param gather: optionally send this information to all peers in the next group and gather it from every groupmate (this operation is known as all-gather). The gathered data will be available as the output of this function. + :param weight: averaging weight for this peer, int or float, must be strictly positive + :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again + within the specified timeout :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
:returns: on success, update averaged_tensors and return group info; on failure, return None """ + assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}" future, _future = MPFuture.make_pair() gather_binary = self.serializer.dumps(gather) # serialize here to avoid loading modules in the averager process - self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, + self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight, allow_retries=allow_retries, timeout=timeout))) return future.result() if wait else future - async def _step(self, *, future: MPFuture, gather_binary: bytes, allow_retries: bool, timeout: Optional[float]): + async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float, + allow_retries: bool, timeout: Optional[float]): loop = asyncio.get_event_loop() start_time = get_dht_time() group_id = None @@ -234,28 +243,28 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, allow_retries: while not future.done(): try: self._pending_group_assembled.clear() - allreduce_group = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=gather_binary) - if allreduce_group is None: + data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary]) + group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather) + if group_info is None: raise AllreduceException("Averaging step failed: could not find a group.") - - group_id = allreduce_group.group_id - self._running_groups[group_id] = allreduce_group + group_id = group_info.group_id + allreduce_runner = await self._make_allreduce_runner(group_info, **self.allreduce_kwargs) + self._running_groups[group_id] = allreduce_runner self._pending_group_assembled.set() - await asyncio.wait_for(allreduce_group.run(), self._allreduce_timeout) - await loop.run_in_executor(None, self.update_tensors, allreduce_group) + await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout) + await loop.run_in_executor(None, self.update_tensors, allreduce_runner) # averaging is finished, exit the loop - gathered_items = map(self.serializer.loads, allreduce_group.gathered) - gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items)) - future.set_result(gathered_data_by_peer) + future.set_result(allreduce_runner.gathered) - except (AllreduceException, MatchmakingException, asyncio.InvalidStateError, - grpc.RpcError, grpc.aio.AioRpcError, InternalError) as e: + except (AllreduceException, MatchmakingException, AssertionError, + asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError, InternalError) as e: time_elapsed = get_dht_time() - start_time if not allow_retries or (timeout is not None and timeout < time_elapsed): + logger.warning(f"Averager caught {e}") future.set_result(None) else: - logger.debug(f"caught {e}, retrying") + logger.warning(f"Averager caught {e}, retrying") except Exception as e: future.set_exception(e) @@ -264,6 +273,23 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, allow_retries: _ = self._running_groups.pop(group_id, None) self._pending_group_assembled.set() + async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner: + """ Use a group description found by Matchmaking to form AllreduceRunner """ + try: + weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, 
group_info.gathered)) + user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered))) + + # compute optimal part sizes from peer throughputs + incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)] + part_sizes = await asyncio.get_event_loop().run_in_executor( + None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size) + async with self.get_tensors_async() as averaged_tensors: + return AllReduceRunner(group_id=group_info.group_id, tensors=averaged_tensors, endpoint=self.endpoint, + ordered_group_endpoints=group_info.endpoints, part_sizes=part_sizes, + weights=weights, gathered=user_gathered, return_deltas=True, **kwargs) + except Exception as e: + raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {group_info}") + def update_tensors(self, allreduce_group: AllReduceRunner): """ a private (extendable) method that applies changes from a finished allreduce to local tensors @@ -288,6 +314,15 @@ def get_tensors(self) -> Sequence[torch.Tensor]: yield self._averaged_tensors self.last_updated = get_dht_time() + @contextlib.asynccontextmanager + async def get_tensors_async(self) -> Sequence[torch.Tensor]: + """ Like get_tensors, but uses an asynchronous contextmanager """ + try: + await asyncio.get_event_loop().run_in_executor(None, self.lock_averaged_tensors.acquire) + yield self._averaged_tensors + finally: + self.lock_averaged_tensors.release() + async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext ) -> AsyncIterator[averaging_pb2.MessageFromLeader]: """ accept or reject a join request from another averager; if accepted, run him through allreduce steps """ @@ -478,3 +513,11 @@ def _background_thread_fetch_current_state(serializer: SerializerBase, pipe: mp. 
future.set_exception(e) logger.warning(e) continue + + +def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes: + """ A hash that describes follower's tensor shapes, dtypes, devices, but not the actual values """ + schema_dicts = [{field_name: str(field_value) + for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()} + for tensor in tensors] + return DHTID.generate(source=schema_dicts).to_bytes() diff --git a/hivemind/client/averaging/allreduce.py b/hivemind/client/averaging/allreduce.py index 74f920781..8ff1f004e 100644 --- a/hivemind/client/averaging/allreduce.py +++ b/hivemind/client/averaging/allreduce.py @@ -30,13 +30,15 @@ def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoi assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group" self.group_id, self.endpoint = group_id, endpoint self.ordered_group_endpoints, self.part_sizes = ordered_group_endpoints, part_sizes - self.client_mode_endpoints = {endpoint for endpoint, size in zip(self.ordered_group_endpoints, part_sizes) if size == 0} + self.client_mode_endpoints = {endpoint for endpoint, part_size in zip(self.ordered_group_endpoints, part_sizes) + if part_size == 0} self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, part_sizes))) self.tensor_shapes = tuple(tensor.shape for tensor in tensors) self.return_deltas = return_deltas - self.accumulator = self.local_tensor_parts[self.endpoint].clone() # sum inputs from peers to this tensor - self.accumulated_from: Set[Endpoint] = {self.endpoint} # peers that we have accumulated our part from + self.accumulator = torch.zeros_like(self.local_tensor_parts[self.endpoint]) + self.denominator = 0.0 # number of peers added to accumulator or sum of their weights + self.accumulated_from: Set[Endpoint] = set() # peers that we have accumulated our part from self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future() # will be set to [accumulator / group size] self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {} # averaged chunks from all peers will be put here self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future() # final result or exception @@ -56,21 +58,23 @@ def __contains__(self, endpoint: Endpoint): def group_size(self): return len(self.ordered_group_endpoints) - async def accumulate_part(self, source: Endpoint, remote_part: torch.Tensor) -> torch.Tensor: + async def accumulate_part(self, source: Endpoint, remote_part: torch.Tensor, weight: float = 1.0) -> torch.Tensor: """ Add vector part to accumulator, wait for all other vectors to be added, then return the average part """ assert not self.averaged_part.done(), f"already finished averaging part: {self.averaged_part}" assert not self.future.done(), f"already finished allreduce: {self.future}" assert source in self.local_tensor_parts, "unexpected source, not a part of current group" assert source not in self.accumulated_from, "duplicate source, already received that part" assert not self.endpoint in self.client_mode_endpoints, f"{self.endpoint} is in client mode" + assert isinstance(weight, (int, float)) and weight > 0, "averaging weights must be a positive int/float" logger.debug(f"{self} - accumulating tensor part from {source}") - self.accumulator.add_(remote_part) + self.accumulator.add_(remote_part, alpha=weight) + self.denominator += weight self.accumulated_from.add(source) assert len(self.accumulated_from) <= self.group_size if len(self.accumulated_from) ==
len(self.local_tensor_parts): - average_result = self.accumulator.div_(len(self.accumulated_from)) + average_result = self.accumulator.div_(self.denominator) self.register_averaged_part(self.endpoint, average_result) self.averaged_part.set_result(average_result) @@ -127,19 +131,21 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint, ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType, - chunk_size_bytes: int, part_sizes: Tuple[int, ...], group_key_seed: int, gathered: Sequence[Any] = (), - return_deltas: bool = False): + chunk_size_bytes: int, part_sizes: Tuple[int, ...], weights: Tuple[float, ...], + gathered: Dict[Endpoint, Any], return_deltas: bool = False): super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint, part_sizes=part_sizes, ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas) self.compression_type, self.chunk_size_bytes, self.gathered = compression_type, chunk_size_bytes, gathered + self.peer_weights = dict(zip(self.ordered_group_endpoints, weights)) self.averaged_part_stream: asyncio.Future[Tuple[runtime_pb2.Tensor, ...]] = asyncio.Future() - self.group_key_seed = group_key_seed def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub: return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True) async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor: """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """ + if peer_endpoint == self.endpoint: + return await self.accumulate_part(self.endpoint, local_part, weight=self.peer_weights[self.endpoint]) serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False) chunks = split_for_streaming(serialized_tensor_part, self.chunk_size_bytes) @@ -178,14 +184,14 @@ async def run(self) -> Sequence[torch.Tensor]: try: await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer]) for i, peer in enumerate(self.ordered_group_endpoints) - if peer != self.endpoint and self.part_sizes[i] > 0)) + if peer not in self.client_mode_endpoints)) return await self except BaseException as e: code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}") self.set_exception(e) - for peer_endpoint in self.ordered_group_endpoints: - if peer_endpoint != self.endpoint: + for peer_endpoint, part_size in zip(self.ordered_group_endpoints, self.part_sizes): + if peer_endpoint != self.endpoint and part_size > 0: asyncio.create_task(self._send_error_to_peer(peer_endpoint, code)) raise @@ -197,7 +203,7 @@ async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Ite except RuntimeError as e: raise AllreduceException(f"Could not deserialize tensor part from {source} for streaming {e}") - averaged_part = await self.accumulate_part(source, tensor_part) + averaged_part = await self.accumulate_part(source, tensor_part, weight=self.peer_weights[source]) if not self.averaged_part_stream.done(): serialized_tensor = serialize_torch_tensor(averaged_part, self.compression_type, allow_inplace=False) stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes)) diff 
--git a/hivemind/client/averaging/group_info.py b/hivemind/client/averaging/group_info.py new file mode 100644 index 000000000..de36a4935 --- /dev/null +++ b/hivemind/client/averaging/group_info.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass +from typing import Tuple + +from hivemind.utils import Endpoint + + +@dataclass(frozen=True) +class GroupInfo: + """ A group of peers assembled through decentralized matchmaking """ + group_id: bytes # random unique bytestring that describes the current group, generated by group leader + endpoints: Tuple[Endpoint, ...] # an ordered sequence of endpoints of each groupmate + gathered: Tuple[bytes, ...] # binary metadata gathered from all peers by leader, same order as endpoints + + @property + def group_size(self): + return len(self.endpoints) + + def __contains__(self, endpoint: Endpoint): + return endpoint in self.endpoints diff --git a/hivemind/client/averaging/key_manager.py b/hivemind/client/averaging/key_manager.py index 6d48a2e48..5ad5a3562 100644 --- a/hivemind/client/averaging/key_manager.py +++ b/hivemind/client/averaging/key_manager.py @@ -6,7 +6,7 @@ import numpy as np from hivemind.dht import DHT -from hivemind.client.averaging.allreduce import AllReduceRunner +from hivemind.client.averaging.group_info import GroupInfo from hivemind.utils import get_logger, Endpoint, DHTExpiration, get_dht_time, ValueWithExpiration GroupKey = str @@ -103,17 +103,17 @@ def get_suggested_nbits(cls, search_result: Optional[ValueWithExpiration]) -> Op else: return None - async def update_key_on_group_assembled(self, allreduce_group: AllReduceRunner, is_leader: bool = True): + async def update_key_on_group_assembled(self, group_info: GroupInfo, is_leader: bool = True): """ this function is triggered every time an averager finds an allreduce group """ - rng = random.Random(allreduce_group.group_key_seed) - index = allreduce_group.ordered_group_endpoints.index(self.endpoint) - generalized_index = rng.sample(range(self.target_group_size), allreduce_group.group_size)[index] + rng = random.Random(group_info.group_id) + index = group_info.endpoints.index(self.endpoint) + generalized_index = rng.sample(range(self.target_group_size), group_info.group_size)[index] nbits = int(np.ceil(np.log2(self.target_group_size))) new_bits = bin(generalized_index)[2:].rjust(nbits, '0') self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits):] if self.group_bits else '' logger.debug(f"{self.endpoint} - updated group key to {self.group_bits}") - if is_leader and self.insufficient_size < allreduce_group.group_size < self.excessive_size: + if is_leader and self.insufficient_size < group_info.group_size < self.excessive_size: asyncio.create_task(self.notify_stragglers()) if self.suggested_nbits is not None and self.suggested_nbits != len(self.group_bits): num_extra_bits = max(0, self.suggested_nbits - len(self.group_bits)) diff --git a/hivemind/client/averaging/matchmaking.py b/hivemind/client/averaging/matchmaking.py index 39710b299..de20ebc02 100644 --- a/hivemind/client/averaging/matchmaking.py +++ b/hivemind/client/averaging/matchmaking.py @@ -4,20 +4,17 @@ import contextlib import random -from dataclasses import asdict from math import isfinite -from typing import Sequence, Optional, AsyncIterator, Set, Tuple, Dict +from typing import Optional, AsyncIterator, Set, Tuple, Dict import concurrent.futures import asyncio import grpc -import torch -from hivemind.client.averaging.allreduce import AllReduceRunner -from hivemind.client.averaging.load_balancing import 
load_balance_peers +from hivemind.client.averaging.group_info import GroupInfo from hivemind.client.averaging.key_manager import GroupKeyManager, GroupKey from hivemind.dht import DHT, DHTID, DHTExpiration, get_dht_time -from hivemind.utils import get_logger, Endpoint, TensorDescriptor, timed_storage, TimedStorage +from hivemind.utils import get_logger, Endpoint, timed_storage, TimedStorage from hivemind.proto import averaging_pb2, averaging_pb2_grpc from hivemind.utils.grpc import ChannelCache @@ -38,24 +35,21 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer): Hence, instead of accounting for such deadlocks, we simply break them with request_timeout. """ - def __init__(self, endpoint: Endpoint, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, - prefix: str, target_group_size: int, min_group_size: int, min_vector_size: int, + def __init__(self, endpoint: Endpoint, schema_hash: bytes, dht: DHT, *, + prefix: str, target_group_size: int, min_group_size: int, request_timeout: float, client_mode: bool, initial_group_bits: Optional[str] = None, - averaging_expiration: float = 15, throughput: Optional[float] = None, **allreduce_kwargs): + averaging_expiration: float = 15): assert '.' not in prefix, "group prefix must be a string without ." if request_timeout is None or request_timeout >= averaging_expiration: logger.warning("It is recommended to use request_timeout smaller than averaging_expiration. Otherwise," "matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring.") super().__init__() - self.endpoint, self.averaged_tensors = endpoint, tuple(averaged_tensors) + self.endpoint, self.schema_hash = endpoint, schema_hash self.group_key_manager = GroupKeyManager(dht, endpoint, prefix, initial_group_bits, target_group_size) self.target_group_size, self.min_group_size = target_group_size, min_group_size self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout self.client_mode = client_mode - self.throughput, self.min_vector_size, self.allreduce_kwargs = throughput, min_vector_size, allreduce_kwargs - self.schema_hash = compute_schema_hash(self.averaged_tensors) - self.total_size = sum(tensor.numel() for tensor in self.averaged_tensors) self.lock_looking_for_group = asyncio.Lock() self.lock_request_join_group = asyncio.Lock() @@ -83,8 +77,7 @@ def __repr__(self): return f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}" \ f" current key = {self.group_key_manager.current_key}, client_mode={self.client_mode})" - async def look_for_group(self, *, data_for_gather: bytes = b'', timeout: Optional[float] = None - ) -> Optional[AllReduceRunner]: + async def look_for_group(self, *, data_for_gather: bytes, timeout: Optional[float] = None) -> Optional[GroupInfo]: """ :param data_for_gather: optionally send this data to all peers in the next group and gather it from groupmates :param timeout: maximum time that may be spent looking for group (does not include allreduce itself) @@ -123,7 +116,7 @@ async def look_for_group(self, *, data_for_gather: bytes = b'', timeout: Optiona self.was_accepted_to_group.clear() self.data_for_gather = None - async def _request_join_potential_leaders(self, timeout: Optional[float]) -> AllReduceRunner: + async def _request_join_potential_leaders(self, timeout: Optional[float]) -> GroupInfo: """ Request leaders from queue until we find the first runner. This coroutine is meant to run in background. 
""" async with self.potential_leaders.begin_search(self.group_key_manager, timeout, declare=not self.client_mode): while True: @@ -151,7 +144,7 @@ async def _request_join_potential_leaders(self, timeout: Optional[float]) -> All self.assembled_group.set_exception(e) raise e - async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpiration) -> Optional[AllReduceRunner]: + async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpiration) -> Optional[GroupInfo]: """ :param leader: request this peer to be your leader for allreduce :param expiration_time: inform leader that we intend to begin averaging before this expiration_time @@ -166,7 +159,6 @@ async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpirat leader_stub = ChannelCache.get_stub(leader, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True) call = leader_stub.rpc_join_group(averaging_pb2.JoinRequest( endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time, - throughput=self.throughput if self.throughput is not None else -1.0, client_mode=self.client_mode, gather=self.data_for_gather)) message = await asyncio.wait_for(call.read(), timeout=self.request_timeout) @@ -255,11 +247,10 @@ async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED) return - allreduce_group = self.assembled_group.result() - yield averaging_pb2.MessageFromLeader( - code=averaging_pb2.BEGIN_ALLREDUCE, group_id=allreduce_group.group_id, - ordered_group_endpoints=allreduce_group.ordered_group_endpoints, part_sizes=allreduce_group.part_sizes, - gathered=allreduce_group.gathered, group_key_seed=allreduce_group.group_key_seed) + group_info = self.assembled_group.result() + yield averaging_pb2.MessageFromLeader(code=averaging_pb2.BEGIN_ALLREDUCE, group_id=group_info.group_id, + ordered_group_endpoints=group_info.endpoints, + gathered=group_info.gathered) except (concurrent.futures.CancelledError, asyncio.CancelledError): return # note: this is a compatibility layer for python3.7 except Exception as e: @@ -296,58 +287,39 @@ def _check_reasons_to_reject(self, request: averaging_pb2.JoinRequest) -> Option else: return None - async def leader_assemble_group(self) -> AllReduceRunner: - """ Form up all current followers into a group and prepare to _run_allreduce """ + async def leader_assemble_group(self) -> GroupInfo: + """ Form up all current followers into a group and gather metadata """ assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked() and not self.client_mode assert not self.assembled_group.done() - group_id = DHTID.generate().to_bytes() + group_id = DHTID.generate().to_bytes() # note: both groupd_id and the order of endpoints must be random ordered_group_endpoints = list(self.current_followers) ordered_group_endpoints.append(self.endpoint) random.shuffle(ordered_group_endpoints) - averager_throughputs, gathered = [], [] - for endpoint in ordered_group_endpoints: - if endpoint == self.endpoint: - averager_throughputs.append(self.throughput) - gathered.append(self.data_for_gather) - else: - follower_info = self.current_followers[endpoint] - throughput = follower_info.throughput if follower_info.throughput >= 0 else None - averager_throughput = throughput if not follower_info.client_mode else 0.0 - averager_throughputs.append(averager_throughput) - gathered.append(follower_info.gather if follower_info.gather else None) - - part_sizes = 
load_balance_peers(self.total_size, averager_throughputs, self.min_vector_size) - group_key_seed = random.randint(- 2 ** 31, 2 ** 31 - 1) - - logger.debug(f"{self.endpoint} - leader started allreduce for {len(ordered_group_endpoints)} peers.") - allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint, - ordered_group_endpoints=ordered_group_endpoints, part_sizes=part_sizes, - gathered=gathered, group_key_seed=group_key_seed, **self.allreduce_kwargs) - await self.group_key_manager.update_key_on_group_assembled(allreduce_group, is_leader=True) - self.assembled_group.set_result(allreduce_group) - return allreduce_group - - async def follower_assemble_group(self, leader: Endpoint, msg: averaging_pb2.MessageFromLeader) -> AllReduceRunner: - """ Prepare to run allreduce using a list of peers provided by our leader """ + gathered = tuple(self.data_for_gather if endpoint == self.endpoint else self.current_followers[endpoint].gather + for endpoint in ordered_group_endpoints) + + logger.debug(f"{self.endpoint} - assembled group of {len(ordered_group_endpoints)} peers.") + group_info = GroupInfo(group_id, tuple(ordered_group_endpoints), gathered) + await self.group_key_manager.update_key_on_group_assembled(group_info, is_leader=True) + self.assembled_group.set_result(group_info) + return group_info + + async def follower_assemble_group(self, leader: Endpoint, msg: averaging_pb2.MessageFromLeader) -> GroupInfo: + """ Form a group using peers and metadata provided by our leader """ assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked() assert not self.assembled_group.done() assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})" - group_id, ordered_group_endpoints, part_sizes = msg.group_id, tuple(msg.ordered_group_endpoints), msg.part_sizes + group_id, ordered_group_endpoints = msg.group_id, msg.ordered_group_endpoints assert self.endpoint in ordered_group_endpoints, "Leader sent us group_endpoints that does not contain us!" - assert len(ordered_group_endpoints) == len(part_sizes) == len(msg.gathered) - my_part_size = part_sizes[ordered_group_endpoints.index(self.endpoint)] - assert my_part_size == 0 or not self.client_mode, "Averager with client_mode=True cannot accept incoming data."
- - logger.debug(f"{self.endpoint} - follower started allreduce after being prompted by leader {leader}.") - allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint, - ordered_group_endpoints=ordered_group_endpoints, - part_sizes=tuple(part_sizes), gathered=msg.gathered, - group_key_seed=int(msg.group_key_seed), **self.allreduce_kwargs) - await self.group_key_manager.update_key_on_group_assembled(allreduce_group) - self.assembled_group.set_result(allreduce_group) - return allreduce_group + assert len(ordered_group_endpoints) == len(msg.gathered) + + logger.debug(f"{self.endpoint} - follower assembled group with leader {leader}.") + group_info = GroupInfo(group_id, tuple(ordered_group_endpoints), tuple(msg.gathered)) + await self.group_key_manager.update_key_on_group_assembled(group_info) + self.assembled_group.set_result(group_info) + return group_info async def leader_disband_group(self): """ Kick out all followers immediately, optionally direct them to our new leader (if we found one) """ @@ -490,13 +462,5 @@ async def _declare_averager_periodically(self, key_manager: GroupKeyManager): looking_for_group=False) -def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes: - """ A hash that describes follower's tensor shapes, dtypes, devices, but not the actual values """ - schema_dicts = [{field_name: str(field_value) - for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()} - for tensor in tensors] - return DHTID.generate(source=schema_dicts).to_bytes() - - class MatchmakingException(Exception): """ An internal exception that marks undesired edge cases during averaging """ diff --git a/hivemind/proto/averaging.proto b/hivemind/proto/averaging.proto index df4dc5cc8..8329f397b 100644 --- a/hivemind/proto/averaging.proto +++ b/hivemind/proto/averaging.proto @@ -35,8 +35,7 @@ message JoinRequest { bytes schema_hash = 2; // A hash that describes follower's tensors (shapes, num tensors, etc) double expiration = 3; // Follower would like to **begin** all_reduce by this point in time bytes gather = 4; // optional metadata that is gathered from all peers (e.g. 
batch size or current loss) - float throughput = 5; // Follower has this bandwidth for averaging (-1 = default) - bool client_mode = 6; // if True, the incoming averager is a client with no capacity for averaging + bool client_mode = 5; // if True, the incoming averager is a client with no capacity for averaging } message MessageFromLeader { @@ -44,9 +43,7 @@ message MessageFromLeader { bytes group_id = 2; // a unique identifier of this group, only valid until allreduce is finished/failed string suggested_leader = 3; // if peer is already in a group, it'll provide us with an endpoint of its leader repeated string ordered_group_endpoints = 4; // a sequence of peers, each responsible for one shard during averaging - repeated int32 part_sizes = 5; // a sequence of tensor parts assigned to each peer, same order as endpoints - repeated bytes gathered = 6; // metadata (gather) from all groupmates in the same order as their endoints - int32 group_key_seed = 7; // a random seed used by peers to update their group keys + repeated bytes gathered = 5; // metadata (gather) from all groupmates in the same order as their endpoints } message AveragingData { @@ -54,6 +51,7 @@ message AveragingData { bytes group_id = 2; // a unique group identifier, same as in MessageFromLeader string endpoint = 3; // sender's rpc endpoint, used for coordination Tensor tensor_part = 4; // either peer's local tensor part (rpc input) or group average of this part (rpc output) + bytes metadata = 5; // reserved user-extendable metadata } message DownloadRequest {} diff --git a/tests/test_averaging.py b/tests/test_averaging.py index 5f30ec0c0..d7c862b3e 100644 --- a/tests/test_averaging.py +++ b/tests/test_averaging.py @@ -80,6 +80,42 @@ def test_allreduce_once(n_client_mode_peers): dht.shutdown() +@pytest.mark.forked +def test_allreduce_weighted(n_client_mode_peers: int = 2): + dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*') + + n_peers = 4 + should_listen = [False] * n_client_mode_peers + [True] * (n_peers - n_client_mode_peers) + random.shuffle(should_listen) + + tensors1 = [torch.randn(123), torch.zeros(3)] + tensors2 = [torch.rand(123), torch.ones(3)] + tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)] + tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2] + averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15, + prefix='mygroup', listen=listen, listen_on='127.0.0.1:*', + start=True) + for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)] + weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01)) + reference = [(tensors1[i] * weights[0] + tensors2[i] * weights[1] + tensors3[i] * weights[2] + + tensors4[i] * weights[3]) / sum(weights) for i in range(len(tensors1))] + + futures = [] + for averager, weight in zip(averagers, weights): + futures.append(averager.step(weight=weight, wait=False)) + for future in futures: + future.result() + + for future, averager in zip(futures, averagers): + with averager.get_tensors() as averaged_tensors: + for ref, our in zip(reference, averaged_tensors): + assert torch.allclose(ref, our, atol=1e-6) + + for averager in averagers: + averager.shutdown() + dht.shutdown() + + def compute_mean_std(averagers, unbiased=True): results = [] for averager in averagers: @@ -174,7 +210,7 @@ async def _accumulate(sender: Endpoint, recipient: Endpoint): sender_allreduce.register_averaged_part(source=recipient, averaged_part=averaged_part) await
asyncio.wait({_accumulate(sender, recipient) for sender in peers for recipient in peers - if sender != recipient and recipient != "colab"}) + if recipient != "colab"}) reference_tensors = [ sum(tensors_by_peer[peer][i] for peer in peers) / len(peers)