diff --git a/hivemind/client/averaging/__init__.py b/hivemind/client/averaging/__init__.py index 91df3be63..6bc642485 100644 --- a/hivemind/client/averaging/__init__.py +++ b/hivemind/client/averaging/__init__.py @@ -20,7 +20,7 @@ import numpy as np from hivemind.dht import DHT, DHTID -from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, split_into_parts +from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode from hivemind.client.averaging.load_balancing import load_balance_peers from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException from hivemind.client.averaging.group_info import GroupInfo @@ -71,6 +71,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)] see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options :param kwargs: extra parameters forwarded to grpc.aio.server + :param auxiliary: if this flag is specified, averager.step will only assist others without sending + local tensors for averaging + :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten + with averager.allow_state_sharing = True / False Example: @@ -94,6 +98,7 @@ def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start: allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0, compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE, throughput: Optional[float] = None, min_vector_size: int = 0, + auxiliary: bool = False, allow_state_sharing: Optional[bool] = None, listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', daemon: bool = True, channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs): assert '.' not in prefix, "group prefix must be a string without trailing '.'" @@ -102,10 +107,18 @@ def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start: if not is_power_of_two(target_group_size): logger.warning("It is recommended to set target_group_size to a power of 2.") assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits) + assert listen or not auxiliary, "auxiliary peers must accept incoming connections" super().__init__() self.dht = dht self.listen, self.listen_on, self.kwargs = listen, listen_on, kwargs + if not self.listen: + self.mode = AveragingMode.CLIENT + elif auxiliary: + self.mode = AveragingMode.AUX + else: + self.mode = AveragingMode.NODE + self.channel_options = channel_options self.daemon = daemon @@ -129,6 +142,10 @@ def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start: self._pipe, self.pipe = mp.Pipe(duplex=True) # a control pipe used to communicate with a background process self._port = mp.Value(ctypes.c_uint32, 0) # assigned when averager starts, accessible via self.port + + self._allow_state_sharing = mp.Value(ctypes.c_bool, 0) + self.allow_state_sharing = (listen and not auxiliary) if allow_state_sharing is None else allow_state_sharing + self._averager_endpoint: Optional[Endpoint] = None if not self.listen: self._averager_endpoint = f'client::{uuid.uuid4()}' @@ -146,6 +163,18 @@ def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start: def port(self) -> Optional[Port]: return self._port.value if self._port.value != 0 else None + @property + def allow_state_sharing(self) -> bool: + """ if set to True, other peers can download this peer's state """ + return bool(self._allow_state_sharing.value) + + @allow_state_sharing.setter + def allow_state_sharing(self, value: bool): + if value is True and not self.listen: + logger.warning("Cannot allow state sharing: averager in client mode (listen=False) cannot share its state.") + else: + self._allow_state_sharing.value = value + @property def endpoint(self) -> Optional[Endpoint]: if self.listen and self._averager_endpoint is None: @@ -236,7 +265,11 @@ def step(self, gather: Optional[DataForGather] = None, weight: float = 1.0, time :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background. :returns: on success, update averaged_tensors and return group info; on failure, return None """ - assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}" + if self.mode == AveragingMode.AUX and weight != 1: + logger.warning("Averager is running in auxiliary mode, weight is unused.") + else: + assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}" + future, _future = MPFuture.make_pair() gather_binary = self.serializer.dumps(gather) # serialize here to avoid loading modules in the averager process self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight, @@ -253,7 +286,7 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float, while not future.done(): try: self._pending_group_assembled.clear() - data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary]) + data_for_gather = self.serializer.dumps([weight, self._throughput, self.mode.value, gather_binary]) group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather) if group_info is None: @@ -263,7 +296,8 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float, self._running_groups[group_id] = allreduce_runner self._pending_group_assembled.set() await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout) - await loop.run_in_executor(None, self.update_tensors, allreduce_runner) + if self.mode != AveragingMode.AUX: + await loop.run_in_executor(None, self.update_tensors, allreduce_runner) # averaging is finished, exit the loop future.set_result(allreduce_runner.gathered) @@ -293,19 +327,19 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float, async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner: """ Use a group description found by Matchmaking to form AllreduceRunner """ try: - weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered)) + weights, throughputs, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered)) user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered))) - # compute optimal part sizes from peer throughputs - incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)] + modes = tuple(map(AveragingMode, mode_ids)) + incoming_throughputs = [thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(throughputs, modes)] # TODO: replace with proper load balancing part_sizes = await asyncio.get_event_loop().run_in_executor( None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size) async with self.get_tensors_async() as averaged_tensors: return AllReduceRunner(group_id=group_info.group_id, tensors=averaged_tensors, endpoint=self.endpoint, ordered_group_endpoints=group_info.endpoints, part_sizes=part_sizes, - weights=weights, gathered=user_gathered, return_deltas=True, **kwargs) + weights=weights, gathered=user_gathered, return_deltas=True, modes=modes, **kwargs) except Exception as e: - raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {group_info}") + raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {weights, throughputs, modes, user_gathered}") def update_tensors(self, allreduce_group: AllReduceRunner): """ @@ -366,10 +400,11 @@ async def rpc_aggregate_part(self, stream: AsyncIterator[averaging_pb2.Averaging async def _declare_for_download_periodically(self): download_key = f'{self._matchmaking.group_key_manager.prefix}.all_averagers' while True: - asyncio.create_task(asyncio.wait_for(self.dht.store( - download_key, subkey=self.endpoint, value=self.last_updated, - expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True), - timeout=self._matchmaking.averaging_expiration)) + if self.allow_state_sharing: + asyncio.create_task(asyncio.wait_for(self.dht.store( + download_key, subkey=self.endpoint, value=self.last_updated, + expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True), + timeout=self._matchmaking.averaging_expiration)) await asyncio.sleep(self._matchmaking.averaging_expiration) async def rpc_download_state(self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext @@ -381,6 +416,8 @@ async def rpc_download_state(self, request: averaging_pb2.DownloadRequest, conte - serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics """ + if not self.allow_state_sharing: + return # deny request and direct peer to the next prospective averager chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES) metadata, tensors = await self._get_current_state_from_host_process() @@ -452,6 +489,11 @@ async def _load_state_from_peers(self, future: MPFuture): current_tensor_parts.append(message.tensor_part) if current_tensor_parts: tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts))) + + if not metadata: + logger.debug(f"Peer {peer} did not send its state.") + continue + logger.info(f"Finished downloading state from {peer}") future.set_result((metadata, tensors)) self.last_updated = get_dht_time() diff --git a/hivemind/client/averaging/allreduce.py b/hivemind/client/averaging/allreduce.py index 06fb3e9da..bb35c481a 100644 --- a/hivemind/client/averaging/allreduce.py +++ b/hivemind/client/averaging/allreduce.py @@ -1,5 +1,6 @@ import asyncio -from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any +from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any, Optional +from enum import Enum import grpc import torch @@ -14,6 +15,12 @@ logger = get_logger(__name__) +class AveragingMode(Enum): + NODE = 0 + CLIENT = 1 + AUX = 2 + + class AllReduceProtocol: """ An internal class that runs butterfly AllReduce in a predefined group of averagers @@ -27,12 +34,16 @@ class AllReduceProtocol: """ def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint, - ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False): + ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False, + modes: Optional[Sequence[AveragingMode]] = None): assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group" self.group_id, self.endpoint = group_id, endpoint self.ordered_group_endpoints, self.part_sizes = ordered_group_endpoints, part_sizes - self.client_mode_endpoints = {endpoint for endpoint, part_size in zip(self.ordered_group_endpoints, part_sizes) - if part_size == 0} + if modes is None: + modes = [AveragingMode.CLIENT if part_size == 0 else AveragingMode.NODE for part_size in part_sizes] + assert any(mode != AveragingMode.CLIENT for mode in modes), "Cannot run allreduce without reducers." + self.peer_modes = dict(zip(ordered_group_endpoints, modes)) + self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, part_sizes))) self.tensor_shapes = tuple(tensor.shape for tensor in tensors) self.return_deltas = return_deltas @@ -43,8 +54,14 @@ def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoi self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future() # will be set to [accumulator / group size] self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {} # averaged chunks from all peers will be put here self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future() # final result or exception - for endpoint in self.client_mode_endpoints: - self.averaged_tensor_parts[endpoint] = torch.tensor([]) + + self.num_senders = len([mode for mode in modes if mode != AveragingMode.AUX]) + + if self.num_senders == 0: + self.future.set_result(None) + for endpoint, mode in self.peer_modes.items(): + if mode == AveragingMode.CLIENT: + self.averaged_tensor_parts[endpoint] = torch.tensor([]) def __repr__(self): return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})" @@ -65,20 +82,24 @@ async def accumulate_part(self, source: Endpoint, remote_part: torch.Tensor, wei assert not self.future.done(), f"already finished allreduce: {self.future}" assert source in self.local_tensor_parts, "unexpected source, not a part of current group" assert source not in self.accumulated_from, "duplicate source, already received that part" - assert not self.endpoint in self.client_mode_endpoints, f"{self.endpoint} is in client mode" + assert self.peer_modes[self.endpoint] != AveragingMode.CLIENT, f"{self.endpoint} is in AveragingMode.client mode" assert isinstance(weight, (int, float)) and weight > 0, "averaging weights must be a non-negative int/float" - logger.debug(f"{self} - accumulating tensor part from {source}") + logger.debug(f"{self} - accumulating tensor part from {source}") self.accumulator.add_(remote_part, alpha=weight) self.denominator += weight self.accumulated_from.add(source) - assert len(self.accumulated_from) <= self.group_size - if len(self.accumulated_from) == len(self.local_tensor_parts): + assert len(self.accumulated_from) <= self.num_senders + if len(self.accumulated_from) == self.num_senders: average_result = self.accumulator.div_(self.denominator) - self.register_averaged_part(self.endpoint, average_result) self.averaged_part.set_result(average_result) + if self.peer_modes[self.endpoint] == AveragingMode.AUX: + self.future.set_result(None) # auxiliary mode has finished averaging + else: + self.register_averaged_part(self.endpoint, average_result) + return await self.averaged_part def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor): @@ -87,6 +108,7 @@ def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor): assert source not in self.averaged_tensor_parts, "already registered the average from this peer" assert averaged_part.shape == self.local_tensor_parts[source].shape, "averaged part shape mismatch" assert averaged_part.dtype == self.local_tensor_parts[source].dtype, "averaged part dtype mismatch" + assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers do not have local tensors for sending" logger.debug(f"{self} - receiving averaged tensor part from {source}") self.averaged_tensor_parts[source] = averaged_part if len(self.averaged_tensor_parts) == len(self.local_tensor_parts): @@ -133,9 +155,9 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint, ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType, chunk_size_bytes: int, part_sizes: Tuple[int, ...], weights: Tuple[float, ...], - gathered: Dict[Endpoint, Any], return_deltas: bool = False): + gathered: Dict[Endpoint, Any], return_deltas: bool = False, **kwargs): super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint, part_sizes=part_sizes, - ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas) + ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas, **kwargs) self.compression_type, self.chunk_size_bytes, self.gathered = compression_type, chunk_size_bytes, gathered self.peer_weights = dict(zip(self.ordered_group_endpoints, weights)) @@ -144,6 +166,7 @@ def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAver async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor: """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """ + assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers are disallowed from sending tensors" if peer_endpoint == self.endpoint: return await self.accumulate_part(self.endpoint, local_part, weight=self.peer_weights[self.endpoint]) serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False) @@ -182,9 +205,10 @@ async def run(self) -> Sequence[torch.Tensor]: send allreduce requests to all peers and collect results, return the averaged tensor (or deltas) """ try: - await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer]) - for i, peer in enumerate(self.ordered_group_endpoints) - if peer not in self.client_mode_endpoints)) + if self.peer_modes[self.endpoint] != AveragingMode.AUX: + await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer]) + for i, peer in enumerate(self.ordered_group_endpoints) + if self.peer_modes[peer] != AveragingMode.CLIENT)) return await self except BaseException as e: code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR diff --git a/hivemind/client/averaging/matchmaking.py b/hivemind/client/averaging/matchmaking.py index 8ec866e51..eb50fdf0f 100644 --- a/hivemind/client/averaging/matchmaking.py +++ b/hivemind/client/averaging/matchmaking.py @@ -391,7 +391,7 @@ async def pop_next_leader(self) -> Endpoint: if maybe_next_leader is None or self.max_assured_time <= entry.expiration_time <= self.search_end_time: self.update_triggered.set() - if maybe_next_leader is None or entry.expiration_time >= self.declared_expiration_time: + if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader) > (self.declared_expiration_time, self.endpoint): await asyncio.wait({self.update_finished.wait(), self.declared_expiration.wait()}, return_when=asyncio.FIRST_COMPLETED) self.declared_expiration.clear() diff --git a/hivemind/optim/collaborative.py b/hivemind/optim/collaborative.py index 5bfde8a4f..72dc27c92 100644 --- a/hivemind/optim/collaborative.py +++ b/hivemind/optim/collaborative.py @@ -191,7 +191,7 @@ def step(self, batch_size: Optional[int] = None, **kwargs): with self.lock_local_progress: self.local_samples_accumulated += batch_size self.local_steps_accumulated += 1 - self.performance_ema.update(num_processed=self.batch_size_per_step) + self.performance_ema.update(num_processed=batch_size) self.should_report_progress.set() if not self.collaboration_state.ready_for_step: diff --git a/tests/test_averaging.py b/tests/test_averaging.py index 65bfea96d..50a185acf 100644 --- a/tests/test_averaging.py +++ b/tests/test_averaging.py @@ -5,7 +5,7 @@ import torch import pytest import hivemind -from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts +from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts, AveragingMode from hivemind.client.averaging.load_balancing import load_balance_peers from hivemind.client.averaging.key_manager import GroupKeyManager from hivemind.utils import Endpoint @@ -41,26 +41,26 @@ async def test_key_manager(): assert len(q5) == 0 -@pytest.mark.forked -@pytest.mark.parametrize("n_client_mode_peers", [0, 2]) -def test_allreduce_once(n_client_mode_peers): +def _test_allreduce_once(n_clients, n_aux): dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*') n_peers = 4 - should_listen = [False] * n_client_mode_peers + [True] * (n_peers - n_client_mode_peers) - random.shuffle(should_listen) - + modes = [AveragingMode.CLIENT] * n_clients + [AveragingMode.AUX] * n_aux + [AveragingMode.NODE] * (n_peers - n_clients - n_aux) + random.shuffle(modes) + tensors1 = [torch.randn(123), torch.zeros(3)] tensors2 = [torch.rand(123), torch.ones(3)] tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)] tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2] - - reference = [(tensors1[i] + tensors2[i] + tensors3[i] + tensors4[i]) / 4 for i in range(len(tensors1))] + peer_tensors = [tensors1, tensors2, tensors3, tensors4] + + reference = [sum(tensors[i] for tensors, mode in zip(peer_tensors, modes) + if mode != AveragingMode.AUX) / max(1, n_peers - n_aux) for i in range(len(tensors1))] averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15, - prefix='mygroup', listen=listen, listen_on='127.0.0.1:*', - start=True) - for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)] + prefix='mygroup', listen=mode != AveragingMode.CLIENT, listen_on='127.0.0.1:*', + auxiliary=mode == AveragingMode.AUX, start=True) + for tensors, mode in zip(peer_tensors, modes)] futures = [] for averager in averagers: @@ -71,15 +71,29 @@ def test_allreduce_once(n_client_mode_peers): assert averager.endpoint in result for averager in averagers: - with averager.get_tensors() as averaged_tensors: - for ref, our in zip(reference, averaged_tensors): - assert torch.allclose(ref, our, atol=1e-6) + if averager.mode != AveragingMode.AUX: + with averager.get_tensors() as averaged_tensors: + for ref, our in zip(reference, averaged_tensors): + assert torch.allclose(ref, our, atol=1e-6) for averager in averagers: averager.shutdown() dht.shutdown() +@pytest.mark.forked +@pytest.mark.parametrize("n_clients", [0, 1, 2]) +@pytest.mark.parametrize("n_aux", [0, 1, 2]) +def test_allreduce_once(n_clients, n_aux): + _test_allreduce_once(n_clients, n_aux) + + +@pytest.mark.forked +@pytest.mark.parametrize("n_clients, n_aux", [(0, 4), (1, 3), (0, 3)]) +def test_allreduce_once_edge_cases(n_clients, n_aux): + _test_allreduce_once(n_clients, n_aux) + + @pytest.mark.forked def test_allreduce_weighted(n_client_mode_peers: int = 2): dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*') @@ -369,6 +383,13 @@ def get_current_state(self): assert got_metadata == super_metadata assert all(map(torch.allclose, got_tensors, super_tensors)) + averager1.allow_state_sharing = False + assert averager2.load_state_from_peers() is None + averager1.allow_state_sharing = True + got_metadata, got_tensors = averager2.load_state_from_peers() + assert num_calls == 3 + assert got_metadata == super_metadata + @pytest.mark.forked def test_getset_bits():