Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support auxiliary participants in AllReduceProtocol #260

Merged
merged 25 commits into from
May 21, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions hivemind/client/averaging/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import numpy as np

from hivemind.dht import DHT, DHTID
from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, split_into_parts
from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, Mode
from hivemind.client.averaging.load_balancing import load_balance_peers
from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
from hivemind.client.averaging.group_info import GroupInfo
Expand Down Expand Up @@ -95,17 +95,22 @@ def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start:
compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
throughput: Optional[float] = None, min_vector_size: int = 0,
listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', daemon: bool = True,
channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
channel_options: Optional[Sequence[Tuple[str, Any]]] = None, auxiliary: bool = False, **kwargs):
assert '.' not in prefix, "group prefix must be a string without trailing '.'"
assert throughput is None or (throughput >= 0 and np.isfinite(np.float32(throughput))), \
"throughput must be a non-negative float32"
if not is_power_of_two(target_group_size):
logger.warning("It is recommended to set target_group_size to a power of 2.")
assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
assert listen or not auxiliary, f"auxiliary peers must accept incoming connections"

super().__init__()
self.dht = dht
self.listen, self.listen_on, self.kwargs = listen, listen_on, kwargs
self.mode = Mode.NODE if self.listen else Mode.CLIENT
if auxiliary:
self.mode = Mode.AUX

self.channel_options = channel_options
self.daemon = daemon

Expand Down Expand Up @@ -237,6 +242,8 @@ def step(self, gather: Optional[DataForGather] = None, weight: float = 1.0, time
:returns: on success, update averaged_tensors and return group info; on failure, return None
"""
assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"
if self.mode == Mode.AUX and weight != 1:
logger.warning("Averager is running in auxiliary mode, weight is unused.")
future, _future = MPFuture.make_pair()
gather_binary = self.serializer.dumps(gather) # serialize here to avoid loading modules in the averager process
self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight,
Expand All @@ -253,7 +260,7 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
while not future.done():
try:
self._pending_group_assembled.clear()
data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
data_for_gather = self.serializer.dumps([weight, self._throughput, self.mode.value, gather_binary])
group_info = await self._matchmaking.look_for_group(timeout=timeout,
data_for_gather=data_for_gather)
if group_info is None:
Expand All @@ -263,7 +270,8 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
self._running_groups[group_id] = allreduce_runner
self._pending_group_assembled.set()
await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
if self.mode != Mode.AUX:
await loop.run_in_executor(None, self.update_tensors, allreduce_runner)

# averaging is finished, exit the loop
future.set_result(allreduce_runner.gathered)
Expand Down Expand Up @@ -293,19 +301,20 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
""" Use a group description found by Matchmaking to form AllreduceRunner """
try:
weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
weights, throughputs, modes_ix, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name `modes_ix` is not clear; it would be better to reflect the meaning directly (at least `mode_inds`).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's still worth changing; I'm not sure `_ix` is a common suffix for indices.

user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))

modes = tuple(map(Mode, modes_ix))

# compute optimal part sizes from peer throughputs
incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]
incoming_throughputs = [thr if mode != Mode.CLIENT else 0.0 for thr, mode in zip(throughputs, modes)] # TODO: replace with proper load balancing
part_sizes = await asyncio.get_event_loop().run_in_executor(
None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size)
async with self.get_tensors_async() as averaged_tensors:
return AllReduceRunner(group_id=group_info.group_id, tensors=averaged_tensors, endpoint=self.endpoint,
ordered_group_endpoints=group_info.endpoints, part_sizes=part_sizes,
weights=weights, gathered=user_gathered, return_deltas=True, **kwargs)
weights=weights, gathered=user_gathered, return_deltas=True, modes=modes, **kwargs)
except Exception as e:
raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {group_info}")
raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {weights, throughputs, modes, user_gathered}")

def update_tensors(self, allreduce_group: AllReduceRunner):
"""
Expand Down
58 changes: 43 additions & 15 deletions hivemind/client/averaging/allreduce.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any
from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any, Optional
from enum import Enum

import grpc
import torch
Expand All @@ -14,6 +15,10 @@
logger = get_logger(__name__)


class Mode(Enum):
    """Role of a peer within an averaging group.

    NODE peers accept incoming connections and take part in reduction,
    CLIENT peers do not listen and are assigned no tensor part to reduce,
    AUX (auxiliary) peers help reduce but never send tensors of their own.
    """

    # The integer values are exchanged between peers during matchmaking
    # (serialized via ``mode.value`` and restored via ``Mode(value)``),
    # so they must stay stable.
    NODE = 0
    CLIENT = 1
    AUX = 2


class AllReduceProtocol:
"""
An internal class that runs butterfly AllReduce in a predefined group of averagers
Expand All @@ -27,12 +32,16 @@ class AllReduceProtocol:
"""

def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False):
ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False,
modes: Optional[Sequence[Mode]] = None):
assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
self.group_id, self.endpoint = group_id, endpoint
self.ordered_group_endpoints, self.part_sizes = ordered_group_endpoints, part_sizes
self.client_mode_endpoints = {endpoint for endpoint, part_size in zip(self.ordered_group_endpoints, part_sizes)
if part_size == 0}
if modes is None:
modes = [Mode.CLIENT if part_size == 0 else Mode.NODE for part_size in part_sizes]
assert any(mode != Mode.CLIENT for mode in modes), "Cannot run allreduce without reducers."
self.peer_modes = dict(zip(ordered_group_endpoints, modes))

self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, part_sizes)))
self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
self.return_deltas = return_deltas
Expand All @@ -43,8 +52,14 @@ def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoi
self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future() # will be set to [accumulator / group size]
self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {} # averaged chunks from all peers will be put here
self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future() # final result or exception
for endpoint in self.client_mode_endpoints:
self.averaged_tensor_parts[endpoint] = torch.tensor([])

self.num_senders = len([mode for mode in modes if mode != Mode.AUX])

if self.num_senders == 0:
self.future.set_result(None)
for endpoint, mode in self.peer_modes.items():
if mode == Mode.CLIENT:
self.averaged_tensor_parts[endpoint] = torch.tensor([])

def __repr__(self):
return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"
Expand All @@ -65,20 +80,26 @@ async def accumulate_part(self, source: Endpoint, remote_part: torch.Tensor, wei
assert not self.future.done(), f"already finished allreduce: {self.future}"
assert source in self.local_tensor_parts, "unexpected source, not a part of current group"
assert source not in self.accumulated_from, "duplicate source, already received that part"
assert not self.endpoint in self.client_mode_endpoints, f"{self.endpoint} is in client mode"
assert self.peer_modes[self.endpoint] != Mode.CLIENT, f"{self.endpoint} is in Mode.client mode"
assert isinstance(weight, (int, float)) and weight > 0, "averaging weights must be a non-negative int/float"
logger.debug(f"{self} - accumulating tensor part from {source}")

print(end=f"{self} - accumulating tensor part from {source}\n")

self.accumulator.add_(remote_part, alpha=weight)
self.denominator += weight
self.accumulated_from.add(source)

assert len(self.accumulated_from) <= self.group_size
if len(self.accumulated_from) == len(self.local_tensor_parts):
assert len(self.accumulated_from) <= self.num_senders
if len(self.accumulated_from) == self.num_senders:
average_result = self.accumulator.div_(self.denominator)
self.register_averaged_part(self.endpoint, average_result)
self.averaged_part.set_result(average_result)

if self.peer_modes[self.endpoint] == Mode.AUX:
self.future.set_result(None) # auxiliary mode has finished averaging
else:
self.register_averaged_part(self.endpoint, average_result)

return await self.averaged_part

def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
Expand All @@ -87,6 +108,7 @@ def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
assert source not in self.averaged_tensor_parts, "already registered the average from this peer"
assert averaged_part.shape == self.local_tensor_parts[source].shape, "averaged part shape mismatch"
assert averaged_part.dtype == self.local_tensor_parts[source].dtype, "averaged part dtype mismatch"
assert self.peer_modes[self.endpoint] != Mode.AUX, "You hear a reasonable explanation on why this behavior is wrong"
logger.debug(f"{self} - receiving averaged tensor part from {source}")
self.averaged_tensor_parts[source] = averaged_part
if len(self.averaged_tensor_parts) == len(self.local_tensor_parts):
Expand Down Expand Up @@ -133,9 +155,9 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType,
chunk_size_bytes: int, part_sizes: Tuple[int, ...], weights: Tuple[float, ...],
gathered: Dict[Endpoint, Any], return_deltas: bool = False):
gathered: Dict[Endpoint, Any], return_deltas: bool = False, **kwargs):
super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint, part_sizes=part_sizes,
ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas)
ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas, **kwargs)
self.compression_type, self.chunk_size_bytes, self.gathered = compression_type, chunk_size_bytes, gathered
self.peer_weights = dict(zip(self.ordered_group_endpoints, weights))

Expand All @@ -144,6 +166,7 @@ def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAver

async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
""" Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
assert self.peer_modes[self.endpoint] != Mode.AUX, "aux peers are disallowed from sending tensors"
if peer_endpoint == self.endpoint:
return await self.accumulate_part(self.endpoint, local_part, weight=self.peer_weights[self.endpoint])
serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
Expand Down Expand Up @@ -182,9 +205,14 @@ async def run(self) -> Sequence[torch.Tensor]:
send allreduce requests to all peers and collect results, return the averaged tensor (or deltas)
"""
try:
await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer])
for i, peer in enumerate(self.ordered_group_endpoints)
if peer not in self.client_mode_endpoints))
if self.peer_modes[self.endpoint] != Mode.AUX:
print(f'{self.endpoint} - SENDING STUFF, {self.peer_modes}')
await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer])
for i, peer in enumerate(self.ordered_group_endpoints)
if self.peer_modes[peer] != Mode.CLIENT))
else:
print(f'{self.endpoint} - NOT SENDING STUFF {self.peer_modes}')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove debug print or convert to a logger call (IMO the first is preferable)


return await self
except BaseException as e:
code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
Expand Down
44 changes: 29 additions & 15 deletions tests/test_averaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import pytest
import hivemind
from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts
from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts, Mode
from hivemind.client.averaging.load_balancing import load_balance_peers
from hivemind.client.averaging.key_manager import GroupKeyManager
from hivemind.utils import Endpoint
Expand Down Expand Up @@ -41,26 +41,26 @@ async def test_key_manager():
assert len(q5) == 0


@pytest.mark.forked
@pytest.mark.parametrize("n_client_mode_peers", [0, 2])
def test_allreduce_once(n_client_mode_peers):
def _test_allreduce_once(n_clients, n_aux):
dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')

n_peers = 4
should_listen = [False] * n_client_mode_peers + [True] * (n_peers - n_client_mode_peers)
random.shuffle(should_listen)

modes = [Mode.CLIENT] * n_clients + [Mode.AUX] * n_aux + [Mode.NODE] * (n_peers - n_clients - n_aux)
random.shuffle(modes)
tensors1 = [torch.randn(123), torch.zeros(3)]
tensors2 = [torch.rand(123), torch.ones(3)]
tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]

reference = [(tensors1[i] + tensors2[i] + tensors3[i] + tensors4[i]) / 4 for i in range(len(tensors1))]
peer_tensors = [tensors1, tensors2, tensors3, tensors4]

reference = [sum(tensors[i] for tensors, mode in zip(peer_tensors, modes)
if mode != Mode.AUX) / max(1, n_peers - n_aux) for i in range(len(tensors1))]

averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
prefix='mygroup', listen=listen, listen_on='127.0.0.1:*',
start=True)
for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)]
prefix='mygroup', listen=mode != Mode.CLIENT, listen_on='127.0.0.1:*',
auxiliary=mode == Mode.AUX, start=True)
for tensors, mode in zip(peer_tensors, modes)]

futures = []
for averager in averagers:
Expand All @@ -71,15 +71,29 @@ def test_allreduce_once(n_client_mode_peers):
assert averager.endpoint in result

for averager in averagers:
with averager.get_tensors() as averaged_tensors:
for ref, our in zip(reference, averaged_tensors):
assert torch.allclose(ref, our, atol=1e-6)
if averager.mode != Mode.AUX:
with averager.get_tensors() as averaged_tensors:
for ref, our in zip(reference, averaged_tensors):
assert torch.allclose(ref, our, atol=1e-6)

for averager in averagers:
averager.shutdown()
dht.shutdown()


@pytest.mark.forked
@pytest.mark.parametrize("n_clients", [0, 1, 2])
@pytest.mark.parametrize("n_aux", [0, 1, 2])
def test_allreduce_once(n_clients, n_aux):
    """Run one all-reduce round for every mix of 0-2 client-mode and 0-2 auxiliary peers."""
    _test_allreduce_once(n_clients=n_clients, n_aux=n_aux)


@pytest.mark.forked
@pytest.mark.parametrize("n_clients, n_aux", [(0, 4), (1, 3), (0, 3)])
def test_allreduce_once_edge_cases(n_clients, n_aux):
    """Run one all-reduce round for groups in which most or all peers are auxiliary."""
    _test_allreduce_once(n_clients=n_clients, n_aux=n_aux)


@pytest.mark.forked
def test_allreduce_weighted(n_client_mode_peers: int = 2):
dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
Expand Down