Support auxiliary participants in AllReduceProtocol #260

Merged (25 commits) on May 21, 2021
68 changes: 55 additions & 13 deletions hivemind/client/averaging/__init__.py
@@ -20,7 +20,7 @@
import numpy as np

from hivemind.dht import DHT, DHTID
from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, split_into_parts
from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode
from hivemind.client.averaging.load_balancing import load_balance_peers
from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
from hivemind.client.averaging.group_info import GroupInfo
@@ -71,6 +71,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
:param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
:param kwargs: extra parameters forwarded to grpc.aio.server
:param auxiliary: if this flag is specified, averager.step will only assist others without sending
local tensors for averaging
:param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
with averager.allow_state_sharing = True / False

Example:
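(The docstring's own example is truncated in this diff view.) As a separate orientation sketch, not part of the PR, the new `auxiliary` and `allow_state_sharing` options might be used as shown below; the DHT setup and the remaining constructor arguments are assumed from the pre-existing `DecentralizedAverager` API.

```python
import torch
import hivemind
from hivemind.client.averaging import DecentralizedAverager

# Assumed setup: a freshly started DHT and a small tensor to average; every argument
# except `auxiliary` and `allow_state_sharing` follows the pre-existing constructor.
dht = hivemind.DHT(start=True)

# A regular peer: contributes its local tensors and (by default) shares its state.
averager = DecentralizedAverager([torch.zeros(16)], dht, prefix='demo_run',
                                 target_group_size=4, start=True)

# An auxiliary peer (new in this PR): must keep listen=True and only assists
# other peers with aggregation; its own tensors are never sent for averaging.
aux = DecentralizedAverager([torch.zeros(16)], dht, prefix='demo_run',
                            target_group_size=4, auxiliary=True, start=True)

# State sharing can also be toggled at runtime through the new property.
averager.allow_state_sharing = False
```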

@@ -94,6 +98,7 @@ def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start:
allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0,
compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
throughput: Optional[float] = None, min_vector_size: int = 0,
auxiliary: bool = False, allow_state_sharing: Optional[bool] = None,
listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', daemon: bool = True,
channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
assert '.' not in prefix, "group prefix must be a string without trailing '.'"
@@ -102,10 +107,18 @@ def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start:
if not is_power_of_two(target_group_size):
logger.warning("It is recommended to set target_group_size to a power of 2.")
assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
assert listen or not auxiliary, "auxiliary peers must accept incoming connections"

super().__init__()
self.dht = dht
self.listen, self.listen_on, self.kwargs = listen, listen_on, kwargs
if not self.listen:
self.mode = AveragingMode.CLIENT
elif auxiliary:
self.mode = AveragingMode.AUX
else:
self.mode = AveragingMode.NODE

self.channel_options = channel_options
self.daemon = daemon

@@ -129,6 +142,10 @@ def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start:

self._pipe, self.pipe = mp.Pipe(duplex=True) # a control pipe used to communicate with a background process
self._port = mp.Value(ctypes.c_uint32, 0) # assigned when averager starts, accessible via self.port

self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
self.allow_state_sharing = (listen and not auxiliary) if allow_state_sharing is None else allow_state_sharing

self._averager_endpoint: Optional[Endpoint] = None
if not self.listen:
self._averager_endpoint = f'client::{uuid.uuid4()}'
@@ -146,6 +163,18 @@ def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start:
def port(self) -> Optional[Port]:
return self._port.value if self._port.value != 0 else None

@property
def allow_state_sharing(self) -> bool:
""" if set to True, other peers can download this peer's state """
return bool(self._allow_state_sharing.value)

@allow_state_sharing.setter
def allow_state_sharing(self, value: bool):
if value is True and not self.listen:
logger.warning("Cannot allow state sharing: averager in client mode (listen=False) cannot share its state.")
else:
self._allow_state_sharing.value = value

@property
def endpoint(self) -> Optional[Endpoint]:
if self.listen and self._averager_endpoint is None:
@@ -236,7 +265,11 @@ def step(self, gather: Optional[DataForGather] = None, weight: float = 1.0, time
:param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
:returns: on success, update averaged_tensors and return group info; on failure, return None
"""
assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"
if self.mode == AveragingMode.AUX and weight != 1:
logger.warning("Averager is running in auxiliary mode, weight is unused.")
else:
assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"

future, _future = MPFuture.make_pair()
gather_binary = self.serializer.dumps(gather) # serialize here to avoid loading modules in the averager process
self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight,
@@ -253,7 +286,7 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
while not future.done():
try:
self._pending_group_assembled.clear()
data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
data_for_gather = self.serializer.dumps([weight, self._throughput, self.mode.value, gather_binary])
group_info = await self._matchmaking.look_for_group(timeout=timeout,
data_for_gather=data_for_gather)
if group_info is None:
@@ -263,7 +296,8 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
self._running_groups[group_id] = allreduce_runner
self._pending_group_assembled.set()
await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
if self.mode != AveragingMode.AUX:
await loop.run_in_executor(None, self.update_tensors, allreduce_runner)

# averaging is finished, exit the loop
future.set_result(allreduce_runner.gathered)
@@ -293,19 +327,19 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
""" Use a group description found by Matchmaking to form AllreduceRunner """
try:
weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
weights, throughputs, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))

# compute optimal part sizes from peer throughputs
incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]
modes = tuple(map(AveragingMode, mode_ids))
incoming_throughputs = [thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(throughputs, modes)] # TODO: replace with proper load balancing
part_sizes = await asyncio.get_event_loop().run_in_executor(
None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size)
async with self.get_tensors_async() as averaged_tensors:
return AllReduceRunner(group_id=group_info.group_id, tensors=averaged_tensors, endpoint=self.endpoint,
ordered_group_endpoints=group_info.endpoints, part_sizes=part_sizes,
weights=weights, gathered=user_gathered, return_deltas=True, **kwargs)
weights=weights, gathered=user_gathered, return_deltas=True, modes=modes, **kwargs)
except Exception as e:
raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {group_info}")
raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {weights, throughputs, modes, user_gathered}")

def update_tensors(self, allreduce_group: AllReduceRunner):
"""
@@ -366,10 +400,11 @@ async def rpc_aggregate_part(self, stream: AsyncIterator[averaging_pb2.Averaging
async def _declare_for_download_periodically(self):
download_key = f'{self._matchmaking.group_key_manager.prefix}.all_averagers'
while True:
asyncio.create_task(asyncio.wait_for(self.dht.store(
download_key, subkey=self.endpoint, value=self.last_updated,
expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
timeout=self._matchmaking.averaging_expiration))
if self.allow_state_sharing:
asyncio.create_task(asyncio.wait_for(self.dht.store(
download_key, subkey=self.endpoint, value=self.last_updated,
expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
timeout=self._matchmaking.averaging_expiration))
await asyncio.sleep(self._matchmaking.averaging_expiration)

async def rpc_download_state(self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext
Expand All @@ -381,6 +416,8 @@ async def rpc_download_state(self, request: averaging_pb2.DownloadRequest, conte
- serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
- tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
"""
if not self.allow_state_sharing:
return # deny request and direct peer to the next prospective averager
chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
metadata, tensors = await self._get_current_state_from_host_process()

@@ -452,6 +489,11 @@ async def _load_state_from_peers(self, future: MPFuture):
current_tensor_parts.append(message.tensor_part)
if current_tensor_parts:
tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))

if not metadata:
logger.debug(f"Peer {peer} did not send its state.")
continue

logger.info(f"Finished downloading state from {peer}")
future.set_result((metadata, tensors))
self.last_updated = get_dht_time()
56 changes: 40 additions & 16 deletions hivemind/client/averaging/allreduce.py
@@ -1,5 +1,6 @@
import asyncio
from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any
from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any, Optional
from enum import Enum

import grpc
import torch
@@ -14,6 +15,12 @@
logger = get_logger(__name__)


class AveragingMode(Enum):
NODE = 0
CLIENT = 1
AUX = 2


class AllReduceProtocol:
"""
An internal class that runs butterfly AllReduce in a predefined group of averagers
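In rough terms (an editor's summary, not text from the PR): NODE peers both send their tensors and reduce a chunk for the group, CLIENT peers (listen=False) send tensors but reduce nothing, and AUX peers reduce a chunk without contributing tensors of their own. A hypothetical helper spelling this out:

```python
from hivemind.client.averaging.allreduce import AveragingMode  # added in this PR

# Illustrative summary of per-mode behaviour during allreduce (not part of the codebase).
ROLE = {
    AveragingMode.NODE:   dict(sends_local_tensors=True,  reduces_a_part=True),
    AveragingMode.CLIENT: dict(sends_local_tensors=True,  reduces_a_part=False),  # listen=False peers
    AveragingMode.AUX:    dict(sends_local_tensors=False, reduces_a_part=True),   # auxiliary peers
}
assert ROLE[AveragingMode.AUX]["sends_local_tensors"] is False
```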
@@ -27,12 +34,16 @@ class AllReduceProtocol:
"""

def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False):
ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False,
modes: Optional[Sequence[AveragingMode]] = None):
assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
self.group_id, self.endpoint = group_id, endpoint
self.ordered_group_endpoints, self.part_sizes = ordered_group_endpoints, part_sizes
self.client_mode_endpoints = {endpoint for endpoint, part_size in zip(self.ordered_group_endpoints, part_sizes)
if part_size == 0}
if modes is None:
modes = [AveragingMode.CLIENT if part_size == 0 else AveragingMode.NODE for part_size in part_sizes]
assert any(mode != AveragingMode.CLIENT for mode in modes), "Cannot run allreduce without reducers."
self.peer_modes = dict(zip(ordered_group_endpoints, modes))

self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, part_sizes)))
self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
self.return_deltas = return_deltas
@@ -43,8 +54,14 @@ def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoi
self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future() # will be set to [accumulator / group size]
self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {} # averaged chunks from all peers will be put here
self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future() # final result or exception
for endpoint in self.client_mode_endpoints:
self.averaged_tensor_parts[endpoint] = torch.tensor([])

self.num_senders = len([mode for mode in modes if mode != AveragingMode.AUX])

if self.num_senders == 0:
self.future.set_result(None)
for endpoint, mode in self.peer_modes.items():
if mode == AveragingMode.CLIENT:
self.averaged_tensor_parts[endpoint] = torch.tensor([])

def __repr__(self):
return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"
@@ -65,20 +82,24 @@ async def accumulate_part(self, source: Endpoint, remote_part: torch.Tensor, wei
assert not self.future.done(), f"already finished allreduce: {self.future}"
assert source in self.local_tensor_parts, "unexpected source, not a part of current group"
assert source not in self.accumulated_from, "duplicate source, already received that part"
assert not self.endpoint in self.client_mode_endpoints, f"{self.endpoint} is in client mode"
assert self.peer_modes[self.endpoint] != AveragingMode.CLIENT, f"{self.endpoint} is in client mode"
assert isinstance(weight, (int, float)) and weight > 0, "averaging weights must be a positive int/float"
logger.debug(f"{self} - accumulating tensor part from {source}")

logger.debug(f"{self} - accumulating tensor part from {source}")
self.accumulator.add_(remote_part, alpha=weight)
self.denominator += weight
self.accumulated_from.add(source)

assert len(self.accumulated_from) <= self.group_size
if len(self.accumulated_from) == len(self.local_tensor_parts):
assert len(self.accumulated_from) <= self.num_senders
if len(self.accumulated_from) == self.num_senders:
average_result = self.accumulator.div_(self.denominator)
self.register_averaged_part(self.endpoint, average_result)
self.averaged_part.set_result(average_result)

if self.peer_modes[self.endpoint] == AveragingMode.AUX:
self.future.set_result(None) # auxiliary mode has finished averaging
else:
self.register_averaged_part(self.endpoint, average_result)

return await self.averaged_part

def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
@@ -87,6 +108,7 @@ def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
assert source not in self.averaged_tensor_parts, "already registered the average from this peer"
assert averaged_part.shape == self.local_tensor_parts[source].shape, "averaged part shape mismatch"
assert averaged_part.dtype == self.local_tensor_parts[source].dtype, "averaged part dtype mismatch"
assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers do not have local tensors for sending"
logger.debug(f"{self} - receiving averaged tensor part from {source}")
self.averaged_tensor_parts[source] = averaged_part
if len(self.averaged_tensor_parts) == len(self.local_tensor_parts):
@@ -133,9 +155,9 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType,
chunk_size_bytes: int, part_sizes: Tuple[int, ...], weights: Tuple[float, ...],
gathered: Dict[Endpoint, Any], return_deltas: bool = False):
gathered: Dict[Endpoint, Any], return_deltas: bool = False, **kwargs):
super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint, part_sizes=part_sizes,
ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas)
ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas, **kwargs)
self.compression_type, self.chunk_size_bytes, self.gathered = compression_type, chunk_size_bytes, gathered
self.peer_weights = dict(zip(self.ordered_group_endpoints, weights))

@@ -144,6 +166,7 @@ def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAver

async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
""" Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers are disallowed from sending tensors"
if peer_endpoint == self.endpoint:
return await self.accumulate_part(self.endpoint, local_part, weight=self.peer_weights[self.endpoint])
serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
@@ -182,9 +205,10 @@ async def run(self) -> Sequence[torch.Tensor]:
send allreduce requests to all peers and collect results, return the averaged tensor (or deltas)
"""
try:
await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer])
for i, peer in enumerate(self.ordered_group_endpoints)
if peer not in self.client_mode_endpoints))
if self.peer_modes[self.endpoint] != AveragingMode.AUX:
await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer])
for i, peer in enumerate(self.ordered_group_endpoints)
if self.peer_modes[peer] != AveragingMode.CLIENT))
return await self
except BaseException as e:
code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR