Implement weights as part of the allreduce protocol, not matchmaking (#384)

* implement weights as part of the allreduce protocol, not matchmaking
* remove metadata field from AveragingData (unused)

Co-authored-by: Alexander Borzunov <[email protected]>
Co-authored-by: Alexander Borzunov <[email protected]>
3 people authored Sep 24, 2021
1 parent d809e30 commit 4a9bc92
Showing 5 changed files with 34 additions and 31 deletions.
36 changes: 20 additions & 16 deletions hivemind/averaging/allreduce.py
@@ -37,13 +37,12 @@ class AllReduceRunner(ServicerBase):
:param prefix: namespace for servicer's RPCs (typically, equal to prefix for group keys)
:param group_id: unique identifier of this specific all-reduce run
:param tensors: local tensors that should be averaged with groupmates
:param tensors: local tensors that should be averaged with groupmates
:param weight: scalar weight of this peer's tensors in the average (doesn't need to sum up to 1)
:param peer_id: your peer_id, must be included in ordered_peer_ids
:param ordered_peer_ids: group peer_ids ordered s.t. i-th peer_id is responsible for averaging i-th part
:param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
(the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
:param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
:param weights: scaling coefficients for weighted averaging (default = equal weights for all non-aux peers)
:param gathered: additional user-defined data collected from this group
:param kwargs: additional paramters (e.g. part_size_bytes) will be passed to TensorPartContainer
"""
@@ -56,9 +55,9 @@ def __init__(
prefix: Optional[str],
group_id: GroupID,
tensors: Sequence[torch.Tensor],
weight: Optional[float] = None,
ordered_peer_ids: Sequence[PeerID],
peer_fractions: Tuple[float, ...],
weights: Optional[Sequence[float]] = None,
modes: Optional[Sequence[AveragingMode]] = None,
gathered: Optional[Dict[PeerID, Any]] = None,
**kwargs,
@@ -73,31 +72,31 @@ def __init__(
self._prefix = prefix

modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
assert len(weights) == len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
assert len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
for mode, frac, weight in zip(modes, peer_fractions, weights):
for mode, frac in zip(modes, peer_fractions):
assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"

self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids
self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered

if weight is None:
weight = float(modes[self.ordered_peer_ids.index(self.peer_id)] != AveragingMode.AUX)
self.weight = weight

self._future = asyncio.Future()

self.sender_peer_ids, self.sender_weights = [], []
for peer_id, weight, mode in zip(self.ordered_peer_ids, weights, modes):
self.sender_peer_ids = []
for peer_id, mode in zip(self.ordered_peer_ids, modes):
if mode != AveragingMode.AUX:
self.sender_peer_ids.append(peer_id)
self.sender_weights.append(weight)

peer_id_index = self.ordered_peer_ids.index(self.peer_id)
self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(peer_id_index)
self.tensor_part_reducer = TensorPartReducer(
tuple(part.shape for part in self.parts_for_local_averaging),
len(self.sender_peer_ids),
self.sender_weights,
)

def __repr__(self):
@@ -149,7 +148,9 @@ async def _communicate_with_peer(self, peer_id: PeerID):
if peer_id == self.peer_id:
sender_index = self.sender_peer_ids.index(peer_id)
for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
averaged_part = await self.tensor_part_reducer.accumulate_part(
sender_index, part_index, tensor_part, weight=self.weight
)
self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)

else:
@@ -180,9 +181,10 @@ async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
code=averaging_pb2.PART_FOR_AVERAGING,
group_id=self.group_id,
tensor_part=first_part,
weight=self.weight,
)
async for part in parts_aiter:
yield averaging_pb2.AveragingData(tensor_part=part)
yield averaging_pb2.AveragingData(tensor_part=part, weight=self.weight)

async def rpc_aggregate_part(
self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
@@ -219,14 +221,16 @@ def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Opti

async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
loop = asyncio.get_event_loop()
async for part_index, (tensor_part, part_compression) in aenumerate(
async for part_index, (tensor_part, weight, part_compression) in aenumerate(
amap_in_executor(
lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.tensor_part.compression),
lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.weight, msg.tensor_part.compression),
stream,
max_prefetch=self.tensor_part_container.prefetch,
)
):
averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
averaged_part = await self.tensor_part_reducer.accumulate_part(
sender_index, part_index, tensor_part, weight=weight
)

serialized_delta = await loop.run_in_executor(
None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
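
The net effect in allreduce.py: each sender now attaches its own scalar weight to every streamed AveragingData message (in _generate_input_for_peer), and the receiving side reads msg.weight off the wire and forwards it into TensorPartReducer.accumulate_part (in _accumulate_parts_streaming). Below is a minimal sketch of that flow; FakeAveragingData is a hypothetical stand-in for the generated protobuf class, and torch is assumed to be available — this is not the actual hivemind code.

from dataclasses import dataclass
from typing import Iterator, List, Tuple

import torch


@dataclass
class FakeAveragingData:
    # hypothetical stand-in for averaging_pb2.AveragingData after this change
    tensor_part: torch.Tensor
    weight: float  # new field: the sender's scalar weight rides along with every part


def generate_input_for_peer(parts: List[torch.Tensor], weight: float) -> Iterator[FakeAveragingData]:
    # mirrors _generate_input_for_peer: the same self.weight is attached to each part
    for part in parts:
        yield FakeAveragingData(tensor_part=part, weight=weight)


def accumulate_parts_streaming(stream: Iterator[FakeAveragingData]) -> List[Tuple[torch.Tensor, float]]:
    # mirrors _accumulate_parts_streaming: the weight is read from the message itself
    return [(msg.tensor_part, msg.weight) for msg in stream]


for tensor_part, weight in accumulate_parts_streaming(
    generate_input_for_peer([torch.ones(2), torch.zeros(3)], weight=2.0)
):
    print(tensor_part, weight)  # each part arrives paired with its sender's weight

Because the weight travels with the data it applies to, matchmaking no longer needs to agree on every peer's weight before the group is formed.
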
9 changes: 5 additions & 4 deletions hivemind/averaging/averager.py
@@ -367,7 +367,7 @@ async def _step(
while not future.done():
try:
self._pending_group_assembled.clear()
data_for_gather = self.serializer.dumps([weight, self.bandwidth, self.mode.value, gather_binary])
data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, gather_binary])
group_info = await self._matchmaking.look_for_group(
timeout=timeout, data_for_gather=data_for_gather
)
@@ -376,7 +376,9 @@ async def _step(

future.set_result(
await asyncio.wait_for(
self._run_allreduce(group_info, tensor_infos=self.tensor_infos, **self.allreduce_kwargs),
self._run_allreduce(
group_info, tensor_infos=self.tensor_infos, weight=weight, **self.allreduce_kwargs
),
timeout=self._allreduce_timeout,
)
)
@@ -414,7 +416,7 @@ async def _step(
async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
"""Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
try:
weights, bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered)))
modes = tuple(map(AveragingMode, mode_ids))

@@ -435,7 +437,6 @@ async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
tensors=local_tensors,
ordered_peer_ids=group_info.peer_ids,
peer_fractions=peer_fractions,
weights=weights,
gathered=user_gathered,
modes=modes,
**kwargs,
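
In averager.py, the peer's weight no longer rides through matchmaking: it is dropped from data_for_gather and handed directly to _run_allreduce, which passes it on to AllReduceRunner via the new weight argument. A rough before/after of the gathered payload — pickle and the concrete values are used purely for illustration, not hivemind's actual serializer or data:

import pickle

bandwidth, mode_value, gather_binary, weight = 100.0, 0, b"user-data", 0.5

# before this commit: the weight was baked into every peer's matchmaking metadata
old_payload = pickle.dumps([weight, bandwidth, mode_value, gather_binary])

# after this commit: matchmaking metadata carries only bandwidth, mode and user data;
# the weight is passed straight into the all-reduce run instead, roughly:
#   self._run_allreduce(group_info, tensor_infos=..., weight=weight, **self.allreduce_kwargs)
new_payload = pickle.dumps([bandwidth, mode_value, gather_binary])

print(pickle.loads(old_payload))
print(pickle.loads(new_payload))
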
14 changes: 6 additions & 8 deletions hivemind/averaging/partition.py
@@ -167,15 +167,11 @@ class TensorPartReducer:
Auxiliary data structure responsible for running asynchronous all-reduce
:param part_shapes: a sequence of shapes of torch tensors that will be averaged by this reducer
:param num_senders: total number of peers in a given all-reduce group that will send gradients
:param weights: relative importance of each sender, used for weighted average (default = equal weights)
:note: even if local peer is not sending data, local parts will be used for shape information
"""

def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int, weights: Optional[Sequence[float]] = None):
def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int):
self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
self.weights = tuple(weights or (1 for _ in range(num_senders)))
assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
assert all(isinstance(weight, (int, float)) for weight in self.weights)
self.current_part_index = -1 # index in local_parts of the part that should be loaded next
self.current_part_accumulated_from = 0 # number of peers from which the current part was accumulated
self.accumulator = None # this will contain the sum of current tensor part from group peers
@@ -197,7 +193,9 @@ def reset_accumulators(self):
self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
self.denominator = 0.0

async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
async def accumulate_part(
self, sender_index: int, part_index: int, tensor_part: torch.Tensor, weight: float = 1.0
) -> torch.Tensor:
"""Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
assert 0 <= sender_index < self.num_senders, "invalid sender index"
assert 0 <= part_index < self.num_parts, "invalid part index"
@@ -211,9 +209,9 @@ async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:

current_part_future = self.current_part_future

self.accumulator.add_(tensor_part, alpha=self.weights[sender_index])
self.denominator += self.weights[sender_index]
self.accumulator.add_(tensor_part, alpha=weight)
self.current_part_accumulated_from += 1
self.denominator += weight

assert self.current_part_accumulated_from <= self.num_senders
if self.current_part_accumulated_from == self.num_senders:
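
With the partition.py change, TensorPartReducer no longer needs a per-sender weights tuple at construction: each accumulate_part call brings its own weight, and the reducer keeps a running weighted sum plus a running denominator, returning sum_i(w_i * x_i) / sum_i(w_i) once every sender has contributed. The toy reducer below illustrates that accumulation pattern for a single part; ToyReducer is hypothetical, not the hivemind class, and multi-part bookkeeping and error handling are omitted.

import asyncio

import torch


class ToyReducer:
    # hypothetical, single-part analogue of TensorPartReducer.accumulate_part
    def __init__(self, part_shape: torch.Size, num_senders: int):
        self.num_senders = num_senders
        self.accumulated_from = 0
        self.accumulator = torch.zeros(part_shape)  # running weighted sum
        self.denominator = 0.0                      # running sum of weights
        self.result: "asyncio.Future[torch.Tensor]" = asyncio.get_running_loop().create_future()

    async def accumulate_part(self, tensor_part: torch.Tensor, weight: float = 1.0) -> torch.Tensor:
        self.accumulator.add_(tensor_part, alpha=weight)
        self.denominator += weight
        self.accumulated_from += 1
        if self.accumulated_from == self.num_senders:
            self.result.set_result(self.accumulator / self.denominator)
        return await self.result  # every sender receives the same weighted average


async def main():
    reducer = ToyReducer(torch.Size([3]), num_senders=2)
    a = asyncio.create_task(reducer.accumulate_part(torch.full((3,), 1.0), weight=1.0))
    b = asyncio.create_task(reducer.accumulate_part(torch.full((3,), 4.0), weight=3.0))
    print(await a)  # tensor([3.25, 3.25, 3.25]) == (1*1.0 + 3*4.0) / (1.0 + 3.0)
    print(await b)  # same tensor: both senders get the group result


asyncio.run(main())
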
2 changes: 1 addition & 1 deletion hivemind/proto/averaging.proto
@@ -45,7 +45,7 @@ message AveragingData {
bytes group_id = 2; // a unique group identifier, same as in MessageFromLeader
bytes peer_id = 3; // sender's rpc peer_id, used for coordination
Tensor tensor_part = 4; // either peer's local tensor part (rpc input) or group average of this part (rpc output)
bytes metadata = 5; // reserved user-extendable metadata
double weight = 5; // peers will be averaged in proportion to these weights
}

message DownloadRequest {}
4 changes: 2 additions & 2 deletions tests/test_allreduce.py
@@ -187,7 +187,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
group_id = random.getrandbits(160).to_bytes(length=20, byteorder="big")

allreduce_protocols = []
for p2p in p2ps:
for i, p2p in enumerate(p2ps):
allreduce_protocol = AllReduceRunner(
p2p=p2p,
servicer_type=AllReduceRunner,
@@ -197,7 +197,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
ordered_peer_ids=peers,
peer_fractions=peer_fractions,
modes=peer_modes,
weights=averaging_weights,
weight=averaging_weights[i],
part_size_bytes=part_size_bytes,
)
await allreduce_protocol.add_p2p_handlers(p2p)
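
The updated test builds one AllReduceRunner per peer and passes that peer's own scalar via weight=averaging_weights[i]; the expected outcome of the protocol is still the weighted average of all non-auxiliary inputs. A quick standalone check of that reference computation — the tensor values and weights below are invented for illustration, not the test's fixtures:

import torch

# one list of tensors per peer, plus each peer's scalar weight (values are made up)
tensors_per_peer = [
    [torch.full((4,), 1.0), torch.full((3,), 10.0)],
    [torch.full((4,), 2.0), torch.full((3,), 20.0)],
    [torch.full((4,), 5.0), torch.full((3,), 50.0)],
]
averaging_weights = [0.5, 0.3, 0.2]

reference = [
    sum(w * t for w, t in zip(averaging_weights, same_tensor_across_peers)) / sum(averaging_weights)
    for same_tensor_across_peers in zip(*tensors_per_peer)
]
print(reference)  # what every non-aux peer should hold once the all-reduce completes
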
