Support auxiliary participants in AllReduceProtocol (#260)
* DecentralizedAverager: support auxiliary peers that assist in allreduce without sending their own data
* implement a flag that disables state sharing in averager
* more natural parameterization of batch_size vs batch_size_per_step
* update test_allreduce_once for new aux peers
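A minimal usage sketch of the two new constructor arguments (not part of this commit; the tensor size, prefix, and group size are made up, and DHT(start=True) is assumed to bring up a standalone DHT node for illustration):

import torch
from hivemind.dht import DHT
from hivemind.client.averaging import DecentralizedAverager

dht = DHT(start=True)  # assumption: a throwaway DHT node just for this sketch

# a regular peer: contributes its local tensors and, by default, shares its state
trainer = DecentralizedAverager(
    averaged_tensors=[torch.zeros(16)], dht=dht, start=True,
    prefix="demo_run", target_group_size=4)

# an auxiliary peer: must accept incoming connections (listen=True); it assists
# in allreduce without sending local tensors and does not share state by default
aux = DecentralizedAverager(
    averaged_tensors=[torch.zeros(16)], dht=dht, start=True,
    prefix="demo_run", target_group_size=4,
    auxiliary=True, listen=True)

# state sharing can also be toggled at runtime through the new property
trainer.allow_state_sharing = False

When allow_state_sharing is left as None, it defaults to (listen and not auxiliary), so client-mode and auxiliary peers do not advertise their state for download unless it is enabled explicitly.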
foksly authored May 21, 2021
1 parent 8673071 commit e58f65d
Showing 5 changed files with 133 additions and 46 deletions.
68 changes: 55 additions & 13 deletions hivemind/client/averaging/__init__.py
@@ -20,7 +20,7 @@
import numpy as np

from hivemind.dht import DHT, DHTID
from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, split_into_parts
from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode
from hivemind.client.averaging.load_balancing import load_balance_peers
from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
from hivemind.client.averaging.group_info import GroupInfo
@@ -71,6 +71,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
:param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
:param kwargs: extra parameters forwarded to grpc.aio.server
:param auxiliary: if this flag is specified, averager.step will only assist others without sending
local tensors for averaging
:param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
with averager.allow_state_sharing = True / False
Example:
@@ -94,6 +98,7 @@ def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start:
allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0,
compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
throughput: Optional[float] = None, min_vector_size: int = 0,
auxiliary: bool = False, allow_state_sharing: Optional[bool] = None,
listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', daemon: bool = True,
channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
assert '.' not in prefix, "group prefix must be a string without trailing '.'"
@@ -102,10 +107,18 @@ def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start:
if not is_power_of_two(target_group_size):
logger.warning("It is recommended to set target_group_size to a power of 2.")
assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
assert listen or not auxiliary, "auxiliary peers must accept incoming connections"

super().__init__()
self.dht = dht
self.listen, self.listen_on, self.kwargs = listen, listen_on, kwargs
if not self.listen:
self.mode = AveragingMode.CLIENT
elif auxiliary:
self.mode = AveragingMode.AUX
else:
self.mode = AveragingMode.NODE

self.channel_options = channel_options
self.daemon = daemon

@@ -129,6 +142,10 @@ def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start:

self._pipe, self.pipe = mp.Pipe(duplex=True) # a control pipe used to communicate with a background process
self._port = mp.Value(ctypes.c_uint32, 0) # assigned when averager starts, accessible via self.port

self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
self.allow_state_sharing = (listen and not auxiliary) if allow_state_sharing is None else allow_state_sharing

self._averager_endpoint: Optional[Endpoint] = None
if not self.listen:
self._averager_endpoint = f'client::{uuid.uuid4()}'
@@ -146,6 +163,18 @@ def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start:
def port(self) -> Optional[Port]:
return self._port.value if self._port.value != 0 else None

@property
def allow_state_sharing(self) -> bool:
""" if set to True, other peers can download this peer's state """
return bool(self._allow_state_sharing.value)

@allow_state_sharing.setter
def allow_state_sharing(self, value: bool):
if value is True and not self.listen:
logger.warning("Cannot allow state sharing: averager in client mode (listen=False) cannot share its state.")
else:
self._allow_state_sharing.value = value

@property
def endpoint(self) -> Optional[Endpoint]:
if self.listen and self._averager_endpoint is None:
@@ -236,7 +265,11 @@ def step(self, gather: Optional[DataForGather] = None, weight: float = 1.0, time
:param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
:returns: on success, update averaged_tensors and return group info; on failure, return None
"""
assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"
if self.mode == AveragingMode.AUX and weight != 1:
logger.warning("Averager is running in auxiliary mode, weight is unused.")
else:
assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"

future, _future = MPFuture.make_pair()
gather_binary = self.serializer.dumps(gather) # serialize here to avoid loading modules in the averager process
self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight,
@@ -253,7 +286,7 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
while not future.done():
try:
self._pending_group_assembled.clear()
data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
data_for_gather = self.serializer.dumps([weight, self._throughput, self.mode.value, gather_binary])
group_info = await self._matchmaking.look_for_group(timeout=timeout,
data_for_gather=data_for_gather)
if group_info is None:
@@ -263,7 +296,8 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
self._running_groups[group_id] = allreduce_runner
self._pending_group_assembled.set()
await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
if self.mode != AveragingMode.AUX:
await loop.run_in_executor(None, self.update_tensors, allreduce_runner)

# averaging is finished, exit the loop
future.set_result(allreduce_runner.gathered)
@@ -293,19 +327,19 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
""" Use a group description found by Matchmaking to form AllreduceRunner """
try:
weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
weights, throughputs, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))

# compute optimal part sizes from peer throughputs
incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]
modes = tuple(map(AveragingMode, mode_ids))
incoming_throughputs = [thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(throughputs, modes)] # TODO: replace with proper load balancing
part_sizes = await asyncio.get_event_loop().run_in_executor(
None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size)
async with self.get_tensors_async() as averaged_tensors:
return AllReduceRunner(group_id=group_info.group_id, tensors=averaged_tensors, endpoint=self.endpoint,
ordered_group_endpoints=group_info.endpoints, part_sizes=part_sizes,
weights=weights, gathered=user_gathered, return_deltas=True, **kwargs)
weights=weights, gathered=user_gathered, return_deltas=True, modes=modes, **kwargs)
except Exception as e:
raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {group_info}")
raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {weights, throughputs, modes, user_gathered}")

def update_tensors(self, allreduce_group: AllReduceRunner):
"""
@@ -366,10 +400,11 @@ async def rpc_aggregate_part(self, stream: AsyncIterator[averaging_pb2.Averaging
async def _declare_for_download_periodically(self):
download_key = f'{self._matchmaking.group_key_manager.prefix}.all_averagers'
while True:
asyncio.create_task(asyncio.wait_for(self.dht.store(
download_key, subkey=self.endpoint, value=self.last_updated,
expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
timeout=self._matchmaking.averaging_expiration))
if self.allow_state_sharing:
asyncio.create_task(asyncio.wait_for(self.dht.store(
download_key, subkey=self.endpoint, value=self.last_updated,
expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
timeout=self._matchmaking.averaging_expiration))
await asyncio.sleep(self._matchmaking.averaging_expiration)

async def rpc_download_state(self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext
@@ -381,6 +416,8 @@ async def rpc_download_state(self, request: averaging_pb2.DownloadRequest, conte
- serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
- tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
"""
if not self.allow_state_sharing:
return # deny request and direct peer to the next prospective averager
chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
metadata, tensors = await self._get_current_state_from_host_process()

@@ -452,6 +489,11 @@ async def _load_state_from_peers(self, future: MPFuture):
current_tensor_parts.append(message.tensor_part)
if current_tensor_parts:
tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))

if not metadata:
logger.debug(f"Peer {peer} did not send its state.")
continue

logger.info(f"Finished downloading state from {peer}")
future.set_result((metadata, tensors))
self.last_updated = get_dht_time()
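Continuing the sketch above (an illustration, not part of the diff): state download now honors allow_state_sharing. Peers with the flag disabled skip the periodic announcement under the "{prefix}.all_averagers" DHT key and deny rpc_download_state, so a downloader simply moves on to the next candidate. The public load_state_from_peers() name is assumed here from the private _load_state_from_peers coroutine shown above.

# assumed public wrapper around _load_state_from_peers; returns the gathered
# (metadata, tensors) tuple from the first peer that agrees to share its state
metadata, tensors = trainer.load_state_from_peers()
# aux.allow_state_sharing is False by default, so `aux` is never listed under
# all_averagers; a peer that replies without metadata is logged and skipped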
56 changes: 40 additions & 16 deletions hivemind/client/averaging/allreduce.py
@@ -1,5 +1,6 @@
import asyncio
from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any
from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any, Optional
from enum import Enum

import grpc
import torch
@@ -14,6 +15,12 @@
logger = get_logger(__name__)


class AveragingMode(Enum):
NODE = 0
CLIENT = 1
AUX = 2


class AllReduceProtocol:
"""
An internal class that runs butterfly AllReduce in a predefined group of averagers
Expand All @@ -27,12 +34,16 @@ class AllReduceProtocol:
"""

def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False):
ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False,
modes: Optional[Sequence[AveragingMode]] = None):
assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
self.group_id, self.endpoint = group_id, endpoint
self.ordered_group_endpoints, self.part_sizes = ordered_group_endpoints, part_sizes
self.client_mode_endpoints = {endpoint for endpoint, part_size in zip(self.ordered_group_endpoints, part_sizes)
if part_size == 0}
if modes is None:
modes = [AveragingMode.CLIENT if part_size == 0 else AveragingMode.NODE for part_size in part_sizes]
assert any(mode != AveragingMode.CLIENT for mode in modes), "Cannot run allreduce without reducers."
self.peer_modes = dict(zip(ordered_group_endpoints, modes))

self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, part_sizes)))
self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
self.return_deltas = return_deltas
@@ -43,8 +54,14 @@ def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoi
self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future() # will be set to [accumulator / group size]
self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {} # averaged chunks from all peers will be put here
self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future() # final result or exception
for endpoint in self.client_mode_endpoints:
self.averaged_tensor_parts[endpoint] = torch.tensor([])

self.num_senders = len([mode for mode in modes if mode != AveragingMode.AUX])

if self.num_senders == 0:
self.future.set_result(None)
for endpoint, mode in self.peer_modes.items():
if mode == AveragingMode.CLIENT:
self.averaged_tensor_parts[endpoint] = torch.tensor([])

def __repr__(self):
return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"
@@ -65,20 +82,24 @@ async def accumulate_part(self, source: Endpoint, remote_part: torch.Tensor, wei
assert not self.future.done(), f"already finished allreduce: {self.future}"
assert source in self.local_tensor_parts, "unexpected source, not a part of current group"
assert source not in self.accumulated_from, "duplicate source, already received that part"
assert not self.endpoint in self.client_mode_endpoints, f"{self.endpoint} is in client mode"
assert self.peer_modes[self.endpoint] != AveragingMode.CLIENT, f"{self.endpoint} is in AveragingMode.client mode"
assert isinstance(weight, (int, float)) and weight > 0, "averaging weights must be a non-negative int/float"
logger.debug(f"{self} - accumulating tensor part from {source}")

logger.debug(f"{self} - accumulating tensor part from {source}")
self.accumulator.add_(remote_part, alpha=weight)
self.denominator += weight
self.accumulated_from.add(source)

assert len(self.accumulated_from) <= self.group_size
if len(self.accumulated_from) == len(self.local_tensor_parts):
assert len(self.accumulated_from) <= self.num_senders
if len(self.accumulated_from) == self.num_senders:
average_result = self.accumulator.div_(self.denominator)
self.register_averaged_part(self.endpoint, average_result)
self.averaged_part.set_result(average_result)

if self.peer_modes[self.endpoint] == AveragingMode.AUX:
self.future.set_result(None) # auxiliary mode has finished averaging
else:
self.register_averaged_part(self.endpoint, average_result)

return await self.averaged_part

def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
@@ -87,6 +108,7 @@ def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
assert source not in self.averaged_tensor_parts, "already registered the average from this peer"
assert averaged_part.shape == self.local_tensor_parts[source].shape, "averaged part shape mismatch"
assert averaged_part.dtype == self.local_tensor_parts[source].dtype, "averaged part dtype mismatch"
assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers do not have local tensors for sending"
logger.debug(f"{self} - receiving averaged tensor part from {source}")
self.averaged_tensor_parts[source] = averaged_part
if len(self.averaged_tensor_parts) == len(self.local_tensor_parts):
@@ -133,9 +155,9 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType,
chunk_size_bytes: int, part_sizes: Tuple[int, ...], weights: Tuple[float, ...],
gathered: Dict[Endpoint, Any], return_deltas: bool = False):
gathered: Dict[Endpoint, Any], return_deltas: bool = False, **kwargs):
super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint, part_sizes=part_sizes,
ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas)
ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas, **kwargs)
self.compression_type, self.chunk_size_bytes, self.gathered = compression_type, chunk_size_bytes, gathered
self.peer_weights = dict(zip(self.ordered_group_endpoints, weights))

@@ -144,6 +166,7 @@ def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAver

async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
""" Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers are disallowed from sending tensors"
if peer_endpoint == self.endpoint:
return await self.accumulate_part(self.endpoint, local_part, weight=self.peer_weights[self.endpoint])
serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
@@ -182,9 +205,10 @@ async def run(self) -> Sequence[torch.Tensor]:
send allreduce requests to all peers and collect results, return the averaged tensor (or deltas)
"""
try:
await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer])
for i, peer in enumerate(self.ordered_group_endpoints)
if peer not in self.client_mode_endpoints))
if self.peer_modes[self.endpoint] != AveragingMode.AUX:
await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer])
for i, peer in enumerate(self.ordered_group_endpoints)
if self.peer_modes[peer] != AveragingMode.CLIENT))
return await self
except BaseException as e:
code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
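A self-contained illustration of the new accumulation rule (the numbers are made up, not from the commit): only non-auxiliary peers send tensor parts, so accumulate_part now waits for num_senders contributions rather than one per group member, and an auxiliary peer finishes as soon as the averaged part has been redistributed.

from enum import Enum
import torch

class AveragingMode(Enum):  # mirrors the enum introduced in allreduce.py
    NODE = 0
    CLIENT = 1
    AUX = 2

modes = [AveragingMode.NODE, AveragingMode.NODE, AveragingMode.AUX]
num_senders = sum(mode != AveragingMode.AUX for mode in modes)  # == 2

# parts received from the two senders, together with their averaging weights
parts = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
weights = [1.0, 3.0]

accumulator, denominator = torch.zeros(2), 0.0
for part, weight in zip(parts, weights):
    accumulator.add_(part, alpha=weight)   # same update as accumulate_part
    denominator += weight

assert len(parts) == num_senders           # the new completion condition
averaged_part = accumulator / denominator  # tensor([2.5000, 3.5000]), returned to senders only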
