
Disentangle DecentralizedAverager components, add weights #217

Merged: 14 commits, Apr 10, 2021
2 changes: 1 addition & 1 deletion docs/modules/client.rst
@@ -21,4 +21,4 @@
.. autoclass:: DecentralizedAverager
:members:
:member-order: bysource
:exclude-members: get_tensors, update_tensors, rpc_join_group, rpc_aggregate_part
:exclude-members: get_tensors, get_tensors_async, update_tensors, rpc_join_group, rpc_aggregate_part
97 changes: 70 additions & 27 deletions hivemind/client/averaging/__init__.py
@@ -11,23 +11,26 @@
import uuid
import weakref
from concurrent.futures.thread import ThreadPoolExecutor
from dataclasses import asdict
from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator

import grpc
from grpc._cython.cygrpc import InternalError
import torch
import numpy as np

import hivemind
from hivemind.dht import DHT, DHTID
from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, split_into_parts
from hivemind.client.averaging.load_balancing import load_balance_peers
from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
from hivemind.client.averaging.group_info import GroupInfo
from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, \
serialize_torch_tensor, deserialize_torch_tensor, split_for_streaming, combine_from_streaming
from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
from hivemind.utils import Endpoint, Port, MPFuture, get_logger
from hivemind.utils import Endpoint, Port, MPFuture, get_logger, TensorDescriptor

# flavour types
StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
@@ -85,7 +88,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
_pending_group_assembled: asyncio.Event
serializer = MSGPackSerializer

def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *, start: bool,
def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start: bool,
prefix: str, target_group_size: int, min_group_size: int = 2, initial_group_bits: Optional[str] = None,
averaging_expiration: float = 15, request_timeout: float = 3, chunk_size_bytes: int = 2 ** 16,
allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0,
@@ -112,12 +115,15 @@ def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.D
for tensor in self._averaged_tensors:
assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
tensor.share_memory_()
self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
self.schema_hash = compute_schema_hash(self._averaged_tensors)
self._throughput = throughput

self.matchmaking_kwargs = dict(
prefix=prefix, initial_group_bits=initial_group_bits, target_group_size=target_group_size,
min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout,
chunk_size_bytes=chunk_size_bytes, compression_type=compression_type,
throughput=throughput, min_vector_size=min_vector_size)
min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout)
self.allreduce_kwargs = dict(compression_type=compression_type, chunk_size_bytes=chunk_size_bytes,
min_vector_size=min_vector_size)
self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
self._running_groups: Dict[GroupID, AllReduceRunner] = {} # one or more assembled groups that run all-reduce

@@ -170,8 +176,8 @@ async def _run():
else:
logger.info(f"The averager running in an experimental client mode, please report any bugs.")

self._matchmaking = Matchmaking(self.endpoint, self._averaged_tensors, self.dht, **self.matchmaking_kwargs,
client_mode=not self.listen, return_deltas=True)
self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs,
client_mode=not self.listen)
if self.listen:
asyncio.create_task(self._declare_for_download_periodically())

@@ -207,55 +213,58 @@ def __del__(self):
if self._parent_pid != os.getpid() or self.is_alive():
self.shutdown()

def step(self, gather: Optional[DataForGather] = None, allow_retries: bool = True, timeout: Optional[float] = None,
wait=True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
def step(self, gather: Optional[DataForGather] = None, weight: float = 1.0, timeout: Optional[float] = None,
allow_retries: bool = True, wait: bool = True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
"""
Set up the averager to look for a group and run one round of averaging; return the gathered metadata on success or None on failure

:param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
within the specified timeout
:param gather: optionally send this information to all peers in the next group and gather it from every groupmate
(this operation is known as all-gather). The gathered data will be available as the output of this function.
:param weight: averaging weight for this peer, int or float, must be strictly positive
:param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
within the specified timeout
:param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
:param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
:returns: on success, update averaged_tensors and return group info; on failure, return None
"""
assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"
future, _future = MPFuture.make_pair()
gather_binary = self.serializer.dumps(gather) # serialize here to avoid loading modules in the averager process
self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary,
self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight,
allow_retries=allow_retries, timeout=timeout)))
return future.result() if wait else future
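A usage sketch of the extended signature (the averager instance comes from the constructor sketch above; the gathered payload is illustrative):

# hedged sketch: this peer contributes its tensors with twice the default weight
gathered = averager.step(gather=dict(local_step=123), weight=2.0, timeout=60)
if gathered is None:
    print("averaging round failed within the timeout")
else:
    print(f"averaged with {len(gathered)} peers")  # maps each endpoint to its gathered metadata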

async def _step(self, *, future: MPFuture, gather_binary: bytes, allow_retries: bool, timeout: Optional[float]):
async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
allow_retries: bool, timeout: Optional[float]):
loop = asyncio.get_event_loop()
start_time = get_dht_time()
group_id = None

while not future.done():
try:
self._pending_group_assembled.clear()
allreduce_group = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=gather_binary)
if allreduce_group is None:
data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather)
if group_info is None:
raise AllreduceException("Averaging step failed: could not find a group.")

group_id = allreduce_group.group_id
self._running_groups[group_id] = allreduce_group
group_id = group_info.group_id
allreduce_runner = await self._make_allreduce_runner(group_info, **self.allreduce_kwargs)
self._running_groups[group_id] = allreduce_runner
self._pending_group_assembled.set()
await asyncio.wait_for(allreduce_group.run(), self._allreduce_timeout)
await loop.run_in_executor(None, self.update_tensors, allreduce_group)
await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
await loop.run_in_executor(None, self.update_tensors, allreduce_runner)

# averaging is finished, exit the loop
gathered_items = map(self.serializer.loads, allreduce_group.gathered)
gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
future.set_result(gathered_data_by_peer)
future.set_result(allreduce_runner.gathered)

except (AllreduceException, MatchmakingException, asyncio.InvalidStateError,
grpc.RpcError, grpc.aio.AioRpcError, InternalError) as e:
except (AllreduceException, MatchmakingException, AssertionError,
asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError, InternalError) as e:
time_elapsed = get_dht_time() - start_time
if not allow_retries or (timeout is not None and timeout < time_elapsed):
logger.warning(f"Averager caught {e}")
future.set_result(None)
else:
logger.debug(f"caught {e}, retrying")
logger.warning(f"Averager caught {e}, retrying")

except Exception as e:
future.set_exception(e)
@@ -264,6 +273,23 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, allow_retries:
_ = self._running_groups.pop(group_id, None)
self._pending_group_assembled.set()

async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
""" Use a group description found by Matchmaking to form AllreduceRunner """
try:
weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))

# compute optimal part sizes from peer throughputs
incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]
part_sizes = load_balance_peers(self.total_size, incoming_throughputs, min_size=min_vector_size)
async with self.get_tensors_async() as averaged_tensors:
return AllReduceRunner(group_id=group_info.group_id, tensors=averaged_tensors, endpoint=self.endpoint,
ordered_group_endpoints=group_info.endpoints, part_sizes=part_sizes,
weights=weights, gathered=user_gathered, return_deltas=True, **kwargs)
except Exception as e:
Member commented:
It's best to catch a more specific exception

Member (author) replied:
This code was intentionally written to catch all kinds of failures,
including, but not limited to:

  • malformed metadata
  • serialization errors
  • negative throughputs
  • unsolvable load_balancing
  • your endpoint is not in group
  • some endpoint is present multiple times
  • AssertionError

[Resolution: let's keep it as is for now]

raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {group_info}")
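For reference, a sketch of the metadata round-trip between _step (which packs [weight, throughput, listen, gather_binary]) and the unpacking above; the values are illustrative:

from hivemind.utils.serializer import MSGPackSerializer

# each peer packs its metadata before matchmaking, as in _step
blob_a = MSGPackSerializer.dumps([1.0, 1e9, True, MSGPackSerializer.dumps(None)])
blob_b = MSGPackSerializer.dumps([2.0, 5e8, False, MSGPackSerializer.dumps(None)])

# the group leader gathers one blob per peer; each follower unpacks them as above
weights, throughputs, modes, user_gathered = zip(*map(MSGPackSerializer.loads, [blob_a, blob_b]))
assert weights == (1.0, 2.0) and modes == (True, False)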

def update_tensors(self, allreduce_group: AllReduceRunner):
"""
a private (extendable) method that applies changes from a finished allreduce to local tensors
Expand All @@ -288,6 +314,15 @@ def get_tensors(self) -> Sequence[torch.Tensor]:
yield self._averaged_tensors
self.last_updated = get_dht_time()

@contextlib.asynccontextmanager
async def get_tensors_async(self) -> Sequence[torch.Tensor]:
""" Like get_tensors, but uses an asynchronous contextmanager """
try:
await asyncio.get_event_loop().run_in_executor(None, self.lock_averaged_tensors.acquire)
yield self._averaged_tensors
finally:
self.lock_averaged_tensors.release()
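A small usage sketch for the asynchronous accessor (hypothetical coroutine running inside the averager's event loop):

async def tensor_norms(averager):
    # hedged sketch: read the shared tensors without blocking the event loop on the lock
    async with averager.get_tensors_async() as tensors:
        return [float(tensor.norm()) for tensor in tensors]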

async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
""" accept or reject a join request from another averager; if accepted, run him through allreduce steps """
@@ -478,3 +513,11 @@ def _background_thread_fetch_current_state(serializer: SerializerBase, pipe: mp.
future.set_exception(e)
logger.warning(e)
continue


def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
""" A hash that describes follower's tensor shapes, dtypes, devices, but not the actual values """
schema_dicts = [{field_name: str(field_value)
for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
for tensor in tensors]
return DHTID.generate(source=schema_dicts).to_bytes()
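As an illustration of the property described in the docstring (tensor values do not affect the hash, while shapes and dtypes do); this check is an assumption-based sketch, not part of the diff:

import torch

same_schema_a = [torch.zeros(3, 4), torch.zeros(5)]
same_schema_b = [torch.ones(3, 4), torch.ones(5)]  # different values, same schema
other_schema = [torch.zeros(3, 4, dtype=torch.float64), torch.zeros(5)]

assert compute_schema_hash(same_schema_a) == compute_schema_hash(same_schema_b)
assert compute_schema_hash(same_schema_a) != compute_schema_hash(other_schema)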
32 changes: 19 additions & 13 deletions hivemind/client/averaging/allreduce.py
@@ -30,13 +30,15 @@ def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoi
assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
self.group_id, self.endpoint = group_id, endpoint
self.ordered_group_endpoints, self.part_sizes = ordered_group_endpoints, part_sizes
self.client_mode_endpoints = {endpoint for endpoint, size in zip(self.ordered_group_endpoints, part_sizes) if size == 0}
self.client_mode_endpoints = {endpoint for endpoint, part_size in zip(self.ordered_group_endpoints, part_sizes)
if part_size == 0}
self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, part_sizes)))
self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
self.return_deltas = return_deltas

self.accumulator = self.local_tensor_parts[self.endpoint].clone() # sum inputs from peers to this tensor
self.accumulated_from: Set[Endpoint] = {self.endpoint} # peers that we have accumulated our part from
self.accumulator = torch.zeros_like(self.local_tensor_parts[self.endpoint])
self.denominator = 0.0 # number of peers added to accumulator or sum of their weights
self.accumulated_from: Set[Endpoint] = set() # peers that we have accumulated our part from
self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future() # will be set to [accumulator / group size]
self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {} # averaged chunks from all peers will be put here
self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future() # final result or exception
@@ -56,21 +58,23 @@ def __contains__(self, endpoint: Endpoint):
def group_size(self):
return len(self.ordered_group_endpoints)

async def accumulate_part(self, source: Endpoint, remote_part: torch.Tensor) -> torch.Tensor:
async def accumulate_part(self, source: Endpoint, remote_part: torch.Tensor, weight: float = 1.0) -> torch.Tensor:
""" Add vector part to accumulator, wait for all other vectors to be added, then return the average part """
assert not self.averaged_part.done(), f"already finished averaging part: {self.averaged_part}"
assert not self.future.done(), f"already finished allreduce: {self.future}"
assert source in self.local_tensor_parts, "unexpected source, not a part of current group"
assert source not in self.accumulated_from, "duplicate source, already received that part"
assert self.endpoint not in self.client_mode_endpoints, f"{self.endpoint} is in client mode"
assert isinstance(weight, (int, float)) and weight > 0, "averaging weight must be a positive int/float"
logger.debug(f"{self} - accumulating tensor part from {source}")

self.accumulator.add_(remote_part)
self.accumulator.add_(remote_part, alpha=weight)
self.denominator += weight
self.accumulated_from.add(source)

assert len(self.accumulated_from) <= self.group_size
if len(self.accumulated_from) == len(self.local_tensor_parts):
average_result = self.accumulator.div_(len(self.accumulated_from))
average_result = self.accumulator.div_(self.denominator)
self.register_averaged_part(self.endpoint, average_result)
self.averaged_part.set_result(average_result)

@@ -127,19 +131,21 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi

def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType,
chunk_size_bytes: int, part_sizes: Tuple[int, ...], group_key_seed: int, gathered: Sequence[Any] = (),
return_deltas: bool = False):
chunk_size_bytes: int, part_sizes: Tuple[int, ...], weights: Tuple[float, ...],
gathered: Dict[Endpoint, Any], return_deltas: bool = False):
super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint, part_sizes=part_sizes,
ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas)
self.compression_type, self.chunk_size_bytes, self.gathered = compression_type, chunk_size_bytes, gathered
self.peer_weights = dict(zip(self.ordered_group_endpoints, weights))
self.averaged_part_stream: asyncio.Future[Tuple[runtime_pb2.Tensor, ...]] = asyncio.Future()
self.group_key_seed = group_key_seed

def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)

async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
""" Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
if peer_endpoint == self.endpoint:
return await self.accumulate_part(self.endpoint, local_part, weight=self.peer_weights[self.endpoint])
serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
chunks = split_for_streaming(serialized_tensor_part, self.chunk_size_bytes)

@@ -178,14 +184,14 @@ async def run(self) -> Sequence[torch.Tensor]:
try:
await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer])
for i, peer in enumerate(self.ordered_group_endpoints)
if peer != self.endpoint and self.part_sizes[i] > 0))
if peer not in self.client_mode_endpoints))
return await self
except BaseException as e:
code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
self.set_exception(e)
for peer_endpoint in self.ordered_group_endpoints:
if peer_endpoint != self.endpoint:
for i, peer_endpoint in enumerate(self.ordered_group_endpoints):
if peer_endpoint != self.endpoint and self.part_sizes[i] > 0:
asyncio.create_task(self._send_error_to_peer(peer_endpoint, code))
raise

@@ -197,7 +203,7 @@ async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Ite
except RuntimeError as e:
raise AllreduceException(f"Could not deserialize tensor part from {source} for streaming {e}")

averaged_part = await self.accumulate_part(source, tensor_part)
averaged_part = await self.accumulate_part(source, tensor_part, weight=self.peer_weights[source])
if not self.averaged_part_stream.done():
serialized_tensor = serialize_torch_tensor(averaged_part, self.compression_type, allow_inplace=False)
stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))