diff --git a/.circleci/config.yml b/.circleci/config.yml
index b1ab5978e..08ac194ae 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -19,7 +19,7 @@ jobs:
           command: pip install -e .
           name: setup
       - run:
-          command: pytest ./tests
+          command: while true; do pytest tests/test_averaging.py::test_allreduce_once[2-2]; done
          name: tests
   build-and-test-py38:
     docker:
@@ -39,7 +39,7 @@ jobs:
           command: pip install -e .
           name: setup
       - run:
-          command: pytest ./tests
+          command: while true; do pytest tests/test_averaging.py::test_allreduce_once[2-2]; done
          name: tests
   build-and-test-py39:
     docker:
@@ -59,7 +59,7 @@ jobs:
           command: pip install -e .
           name: setup
       - run:
-          command: pytest ./tests
+          command: while true; do pytest tests/test_averaging.py::test_allreduce_once[2-2]; done
          name: tests
 
 workflows:
diff --git a/hivemind/client/averaging/__init__.py b/hivemind/client/averaging/__init__.py
index 91df3be63..293b01c7a 100644
--- a/hivemind/client/averaging/__init__.py
+++ b/hivemind/client/averaging/__init__.py
@@ -20,7 +20,7 @@
 import numpy as np
 
 from hivemind.dht import DHT, DHTID
-from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, split_into_parts
+from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode
 from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.client.averaging.group_info import GroupInfo
@@ -71,6 +71,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
       see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
     :param kwargs: extra parameters forwarded to grpc.aio.server
+    :param auxiliary: if this flag is specified, averager.step will only assist others without sending
+      local tensors for averaging
 
     Example:
 
@@ -95,17 +97,25 @@ def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start:
                  compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
                  throughput: Optional[float] = None, min_vector_size: int = 0, listen: bool = True,
                  listen_on: Endpoint = '0.0.0.0:*', daemon: bool = True,
-                 channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
+                 channel_options: Optional[Sequence[Tuple[str, Any]]] = None, auxiliary: bool = False, **kwargs):
         assert '.' not in prefix, "group prefix must be a string without trailing '.'"
         assert throughput is None or (throughput >= 0 and np.isfinite(np.float32(throughput))), \
             "throughput must be a non-negative float32"
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
         assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
+        assert listen or not auxiliary, "auxiliary peers must accept incoming connections"
 
         super().__init__()
         self.dht = dht
         self.listen, self.listen_on, self.kwargs = listen, listen_on, kwargs
+        if not self.listen:
+            self.mode = AveragingMode.CLIENT
+        elif auxiliary:
+            self.mode = AveragingMode.AUX
+        else:
+            self.mode = AveragingMode.NODE
+
         self.channel_options = channel_options
         self.daemon = daemon
 
@@ -237,6 +247,8 @@ def step(self, gather: Optional[DataForGather] = None, weight: float = 1.0, time
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
         assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"
+        if self.mode == AveragingMode.AUX and weight != 1:
+            logger.warning("Averager is running in auxiliary mode, weight is unused.")
         future, _future = MPFuture.make_pair()
         gather_binary = self.serializer.dumps(gather)  # serialize here to avoid loading modules in the averager process
         self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight,
@@ -253,7 +265,7 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
         while not future.done():
             try:
                 self._pending_group_assembled.clear()
-                data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
+                data_for_gather = self.serializer.dumps([weight, self._throughput, self.mode.value, gather_binary])
                 group_info = await self._matchmaking.look_for_group(timeout=timeout,
                                                                     data_for_gather=data_for_gather)
                 if group_info is None:
@@ -263,7 +275,8 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
                 self._running_groups[group_id] = allreduce_runner
                 self._pending_group_assembled.set()
                 await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
-                await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
+                if self.mode != AveragingMode.AUX:
+                    await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
 
                 # averaging is finished, exit the loop
                 future.set_result(allreduce_runner.gathered)
@@ -293,19 +306,19 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
 
     async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
         """ Use a group description found by Matchmaking to form AllreduceRunner """
         try:
-            weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
+            weights, throughputs, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
 
-            # compute optimal part sizes from peer throughputs
-            incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]
+            modes = tuple(map(AveragingMode, mode_ids))
+            incoming_throughputs = [thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(throughputs, modes)]  # TODO: replace with proper load balancing
             part_sizes = await asyncio.get_event_loop().run_in_executor(
                 None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size)
             async with self.get_tensors_async() as averaged_tensors:
                 return AllReduceRunner(group_id=group_info.group_id, tensors=averaged_tensors, endpoint=self.endpoint,
                                        ordered_group_endpoints=group_info.endpoints, part_sizes=part_sizes,
-                                       weights=weights, gathered=user_gathered, return_deltas=True, **kwargs)
+                                       weights=weights, gathered=user_gathered, return_deltas=True, modes=modes, **kwargs)
         except Exception as e:
-            raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {group_info}")
+            raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {weights, throughputs, modes, user_gathered}")
 
     def update_tensors(self, allreduce_group: AllReduceRunner):
         """
diff --git a/hivemind/client/averaging/allreduce.py b/hivemind/client/averaging/allreduce.py
index 06fb3e9da..bb35c481a 100644
--- a/hivemind/client/averaging/allreduce.py
+++ b/hivemind/client/averaging/allreduce.py
@@ -1,5 +1,6 @@
 import asyncio
-from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any
+from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any, Optional
+from enum import Enum
 
 import grpc
 import torch
@@ -14,6 +15,12 @@
 logger = get_logger(__name__)
 
 
+class AveragingMode(Enum):
+    NODE = 0
+    CLIENT = 1
+    AUX = 2
+
+
 class AllReduceProtocol:
     """
     An internal class that runs butterfly AllReduce in a predefined group of averagers
@@ -27,12 +34,16 @@ class AllReduceProtocol:
     """
 
     def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
-                 ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False):
+                 ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False,
+                 modes: Optional[Sequence[AveragingMode]] = None):
         assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
         self.group_id, self.endpoint = group_id, endpoint
         self.ordered_group_endpoints, self.part_sizes = ordered_group_endpoints, part_sizes
-        self.client_mode_endpoints = {endpoint for endpoint, part_size in zip(self.ordered_group_endpoints, part_sizes)
-                                      if part_size == 0}
+        if modes is None:
+            modes = [AveragingMode.CLIENT if part_size == 0 else AveragingMode.NODE for part_size in part_sizes]
+        assert any(mode != AveragingMode.CLIENT for mode in modes), "Cannot run allreduce without reducers."
+        self.peer_modes = dict(zip(ordered_group_endpoints, modes))
+
         self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, part_sizes)))
         self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
         self.return_deltas = return_deltas
@@ -43,8 +54,14 @@ def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoi
         self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future()  # will be set to [accumulator / group size]
         self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {}  # averaged chunks from all peers will be put here
         self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()  # final result or exception
-        for endpoint in self.client_mode_endpoints:
-            self.averaged_tensor_parts[endpoint] = torch.tensor([])
+
+        self.num_senders = len([mode for mode in modes if mode != AveragingMode.AUX])
+
+        if self.num_senders == 0:
+            self.future.set_result(None)
+        for endpoint, mode in self.peer_modes.items():
+            if mode == AveragingMode.CLIENT:
+                self.averaged_tensor_parts[endpoint] = torch.tensor([])
 
     def __repr__(self):
         return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"
@@ -65,20 +82,24 @@ async def accumulate_part(self, source: Endpoint, remote_part: torch.Tensor, wei
         assert not self.future.done(), f"already finished allreduce: {self.future}"
         assert source in self.local_tensor_parts, "unexpected source, not a part of current group"
         assert source not in self.accumulated_from, "duplicate source, already received that part"
-        assert not self.endpoint in self.client_mode_endpoints, f"{self.endpoint} is in client mode"
+        assert self.peer_modes[self.endpoint] != AveragingMode.CLIENT, f"{self.endpoint} is in AveragingMode.client mode"
         assert isinstance(weight, (int, float)) and weight > 0, "averaging weights must be a non-negative int/float"
 
-        logger.debug(f"{self} - accumulating tensor part from {source}")
+        logger.debug(f"{self} - accumulating tensor part from {source}")
         self.accumulator.add_(remote_part, alpha=weight)
         self.denominator += weight
         self.accumulated_from.add(source)
 
-        assert len(self.accumulated_from) <= self.group_size
-        if len(self.accumulated_from) == len(self.local_tensor_parts):
+        assert len(self.accumulated_from) <= self.num_senders
+        if len(self.accumulated_from) == self.num_senders:
             average_result = self.accumulator.div_(self.denominator)
-            self.register_averaged_part(self.endpoint, average_result)
             self.averaged_part.set_result(average_result)
 
+            if self.peer_modes[self.endpoint] == AveragingMode.AUX:
+                self.future.set_result(None)  # auxiliary mode has finished averaging
+            else:
+                self.register_averaged_part(self.endpoint, average_result)
+
         return await self.averaged_part
 
     def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
@@ -87,6 +108,7 @@ def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
         assert source not in self.averaged_tensor_parts, "already registered the average from this peer"
         assert averaged_part.shape == self.local_tensor_parts[source].shape, "averaged part shape mismatch"
         assert averaged_part.dtype == self.local_tensor_parts[source].dtype, "averaged part dtype mismatch"
+        assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers do not have local tensors for sending"
         logger.debug(f"{self} - receiving averaged tensor part from {source}")
         self.averaged_tensor_parts[source] = averaged_part
         if len(self.averaged_tensor_parts) == len(self.local_tensor_parts):
@@ -133,9 +155,9 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
     def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
                  ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType,
                  chunk_size_bytes: int, part_sizes: Tuple[int, ...], weights: Tuple[float, ...],
-                 gathered: Dict[Endpoint, Any], return_deltas: bool = False):
+                 gathered: Dict[Endpoint, Any], return_deltas: bool = False, **kwargs):
         super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint, part_sizes=part_sizes,
-                         ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas)
+                         ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas, **kwargs)
         self.compression_type, self.chunk_size_bytes, self.gathered = compression_type, chunk_size_bytes, gathered
         self.peer_weights = dict(zip(self.ordered_group_endpoints, weights))
 
@@ -144,6 +166,7 @@ def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAver
 
     async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
         """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
+        assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers are disallowed from sending tensors"
         if peer_endpoint == self.endpoint:
             return await self.accumulate_part(self.endpoint, local_part, weight=self.peer_weights[self.endpoint])
         serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
@@ -182,9 +205,10 @@ async def run(self) -> Sequence[torch.Tensor]:
         send allreduce requests to all peers and collect results, return the averaged tensor (or deltas)
         """
         try:
-            await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer])
-                                         for i, peer in enumerate(self.ordered_group_endpoints)
-                                         if peer not in self.client_mode_endpoints))
+            if self.peer_modes[self.endpoint] != AveragingMode.AUX:
+                await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer])
+                                             for i, peer in enumerate(self.ordered_group_endpoints)
+                                             if self.peer_modes[peer] != AveragingMode.CLIENT))
             return await self
         except BaseException as e:
             code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
diff --git a/hivemind/client/averaging/matchmaking.py b/hivemind/client/averaging/matchmaking.py
index 8ec866e51..402be2369 100644
--- a/hivemind/client/averaging/matchmaking.py
+++ b/hivemind/client/averaging/matchmaking.py
@@ -153,6 +153,7 @@ async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpirat
         :note: this function does not guarantee that your group leader is the same as :leader: parameter
           The originally specified leader can disband group and redirect us to a different leader
         """
+        print(f"{self.endpoint} - REQUEST TO {leader}")
         assert self.is_looking_for_group and self.current_leader is None
         call: Optional[grpc.aio.UnaryStreamCall] = None
         try:
@@ -165,14 +166,19 @@ async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpirat
 
                 if message.code == averaging_pb2.ACCEPTED:
                     logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
+                    print(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
+
                     self.current_leader = leader
                     self.was_accepted_to_group.set()
                     if len(self.current_followers) > 0:
                         await self.leader_disband_group()
+                        print(f"{self.endpoint} - DISBANDED GROUP")
+
 
             if message.code != averaging_pb2.ACCEPTED:
                 code = averaging_pb2.MessageCode.Name(message.code)
                 logger.debug(f"{self.endpoint} - requested {leader} to be my leader, but got rejected with {code}")
+                print(f"{self.endpoint} - requested {leader} to be my leader, but got rejected with {code}")
                 return None
 
             async with self.potential_leaders.pause_search():
@@ -180,23 +186,30 @@ async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpirat
                 message = await asyncio.wait_for(call.read(), time_to_expiration + self.request_timeout)
 
                 if message.code == averaging_pb2.BEGIN_ALLREDUCE:
+                    print(f"{self.endpoint} - beginning alreduce")
+
                     async with self.lock_request_join_group:
                         return await self.follower_assemble_group(leader, message)
+                else:
+                    print(f"{self.endpoint} - NOT beginning alreduce due to receiving code {message.code}")
 
                 if message.code in (averaging_pb2.GROUP_DISBANDED, averaging_pb2.CANCELLED):
                     if message.suggested_leader and message.suggested_leader != self.endpoint:
+                        print(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
                         logger.debug(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
                         self.current_leader = None
                         call.cancel()
                         return await self.request_join_group(message.suggested_leader, expiration_time)
                     else:
+                        print(f"{self} - leader disbanded group/")
                         logger.debug(f"{self} - leader disbanded group")
                         return None
 
-
+                print(f"{self} - unexpected message from leader: {averaging_pb2.MessageCode.Name(message.code)}")
+                logger.debug(f"{self} - unexpected message from leader: {averaging_pb2.MessageCode.Name(message.code)}")
                 return None
         except asyncio.TimeoutError:
-            logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
+            print(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
             if call is not None:
                 call.cancel()
             return None
@@ -214,17 +227,23 @@ async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc
                              ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
         """ accept or reject a join request from another averager; if accepted, run him through allreduce steps """
         try:
+            print(f"{self.endpoint} - incoming request from {request.endpoint} (time={get_dht_time()})")
             async with self.lock_request_join_group:
                 reason_to_reject = self._check_reasons_to_reject(request)
                 if reason_to_reject is not None:
+                    print(f"{self.endpoint} - rejected request from {request.endpoint}")
                     yield reason_to_reject
                     return
 
                 self.current_followers[request.endpoint] = request
+                print(f"{self.endpoint} - accepted request from {request.endpoint}")
                 yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
 
                 if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
                     # outcome 1: we have assembled a full group and are ready for allreduce
+                    print(f"{self.endpoint} - beginning allreduce because assembled full group: {self.current_followers} (plus {self.endpoint})"
+                          f" target size = {self.target_group_size}")
+
                     await self.leader_assemble_group()
 
             # wait for the group to be assembled or disbanded
@@ -237,8 +256,12 @@ async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc
                     pass  # this covers a rare case when the group is assembled while the event loop was busy.
                 elif len(self.current_followers) + 1 >= self.min_group_size and self.is_looking_for_group:
                     # outcome 2: the time is up, run allreduce with what we have or disband
+                    print(f"{self.endpoint} - beginning allreduce because time is up, group: {self.current_followers} (plus {self.endpoint})"
+                          f"target size = {self.target_group_size} (time={get_dht_time()})")
                     await self.leader_assemble_group()
                 else:
+                    print(f"{self.endpoint} - disbanging group because time is up, group: {self.current_followers} (plus {self.endpoint})"
+                          f"target size = {self.target_group_size} (time={get_dht_time()})")
                     await self.leader_disband_group()
 
             if self.was_accepted_to_group.is_set() or not self.assembled_group.done() \
@@ -265,6 +288,8 @@ async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc
         finally:  # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
             self.current_followers.pop(request.endpoint, None)
             self.follower_was_discarded.set()
+            print(f"{self.endpoint} - finished processing request from {request.endpoint}")
+
 
     def _check_reasons_to_reject(self, request: averaging_pb2.JoinRequest) -> Optional[averaging_pb2.MessageFromLeader]:
         """ :returns: if accepted, return None, otherwise return a reason for rejection """
@@ -386,12 +411,14 @@ async def pop_next_leader(self) -> Endpoint:
         """ Remove and return the next most suitable leader or throw an exception if reached timeout """
         assert self.running.is_set(), "Not running search at the moment"
         while True:
+            print(f"QQ{self.endpoint} - current queue = {self.leader_queue.data}")
             maybe_next_leader, entry = self.leader_queue.top()
 
             if maybe_next_leader is None or self.max_assured_time <= entry.expiration_time <= self.search_end_time:
                 self.update_triggered.set()
 
-            if maybe_next_leader is None or entry.expiration_time >= self.declared_expiration_time:
+            if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader) > (self.declared_expiration_time, self.endpoint):
+                print(f"QQ{self.endpoint} - awaiting queue update")
                 await asyncio.wait({self.update_finished.wait(), self.declared_expiration.wait()},
                                    return_when=asyncio.FIRST_COMPLETED)
                 self.declared_expiration.clear()
@@ -403,6 +430,8 @@ async def pop_next_leader(self) -> Endpoint:
             del self.leader_queue[maybe_next_leader]
             self.past_attempts.add((maybe_next_leader, entry.expiration_time))
+            print(f"QQ{self.endpoint} - yielding {maybe_next_leader}")
+
             return maybe_next_leader
 
     @property
diff --git a/tests/test_averaging.py b/tests/test_averaging.py
index 65bfea96d..a69b17f8d 100644
--- a/tests/test_averaging.py
+++ b/tests/test_averaging.py
@@ -5,7 +5,7 @@
 import torch
 import pytest
 import hivemind
-from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts
+from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts, AveragingMode
 from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.key_manager import GroupKeyManager
 from hivemind.utils import Endpoint
@@ -41,45 +41,57 @@ async def test_key_manager():
     assert len(q5) == 0
 
 
-@pytest.mark.forked
-@pytest.mark.parametrize("n_client_mode_peers", [0, 2])
-def test_allreduce_once(n_client_mode_peers):
+def _test_allreduce_once(n_clients, n_aux):
     dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
 
     n_peers = 4
-    should_listen = [False] * n_client_mode_peers + [True] * (n_peers - n_client_mode_peers)
-    random.shuffle(should_listen)
+    modes = [AveragingMode.CLIENT] * n_clients + [AveragingMode.AUX] * n_aux + [AveragingMode.NODE] * (n_peers - n_clients - n_aux)
+    # random.shuffle(modes)
 
     tensors1 = [torch.randn(123), torch.zeros(3)]
     tensors2 = [torch.rand(123), torch.ones(3)]
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
     tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
+    peer_tensors = [tensors1, tensors2, tensors3, tensors4]
 
-    reference = [(tensors1[i] + tensors2[i] + tensors3[i] + tensors4[i]) / 4 for i in range(len(tensors1))]
+    reference = [sum(tensors[i] for tensors, mode in zip(peer_tensors, modes)
+                     if mode != AveragingMode.AUX) / max(1, n_peers - n_aux) for i in range(len(tensors1))]
 
     averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
-                                                prefix='mygroup', listen=listen, listen_on='127.0.0.1:*',
-                                                start=True)
-                 for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)]
-
+                                                prefix='mygroup', listen=mode != AveragingMode.CLIENT, listen_on='127.0.0.1:*',
+                                                auxiliary=(mode == AveragingMode.AUX), start=True)
+                 for tensors, mode in zip(peer_tensors, modes)]
     futures = []
     for averager in averagers:
         futures.append(averager.step(wait=False))
     for future in futures:
         result = future.result()
         for averager in averagers:
-            assert averager.endpoint in result
-
+            assert averager.endpoint in result, f"{modes}"
     for averager in averagers:
-        with averager.get_tensors() as averaged_tensors:
-            for ref, our in zip(reference, averaged_tensors):
-                assert torch.allclose(ref, our, atol=1e-6)
+        if averager.mode != AveragingMode.AUX:
+            with averager.get_tensors() as averaged_tensors:
+                for ref, our in zip(reference, averaged_tensors):
+                    assert torch.allclose(ref, our, atol=1e-6)
 
     for averager in averagers:
         averager.shutdown()
     dht.shutdown()
 
 
+@pytest.mark.forked
+@pytest.mark.parametrize("n_clients", [0, 1, 2])
+@pytest.mark.parametrize("n_aux", [0, 1, 2])
+def test_allreduce_once(n_clients, n_aux):
+    _test_allreduce_once(n_clients, n_aux)
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize("n_clients, n_aux", [(0, 4), (1, 3), (0, 3)])
+def test_allreduce_once_edge_cases(n_clients, n_aux):
+    _test_allreduce_once(n_clients, n_aux)
+
+
 @pytest.mark.forked
 def test_allreduce_weighted(n_client_mode_peers: int = 2):
     dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')