diff --git a/hivemind/averaging/allreduce.py b/hivemind/averaging/allreduce.py
index 0bcd016d7..003138b58 100644
--- a/hivemind/averaging/allreduce.py
+++ b/hivemind/averaging/allreduce.py
@@ -37,13 +37,12 @@ class AllReduceRunner(ServicerBase):
     :param prefix: namespace for servicer's RPCs (typically, equal to prefix for group keys)
     :param group_id: unique identifier of this specific all-reduce run
     :param tensors: local tensors that should be averaged with groupmates
-    :param tensors: local tensors that should be averaged with groupmates
+    :param weight: scalar weight of this peer's tensors in the average (doesn't need to sum up to 1)
     :param peer_id: your peer_id, must be included in ordered_peer_ids
     :param ordered_peer_ids: group peer_ids ordered s.t. i-th peer_id is responsible for averaging i-th part
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
     :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
-    :param weights: scaling coefficients for weighted averaging (default = equal weights for all non-aux peers)
     :param gathered: additional user-defined data collected from this group
     :param kwargs: additional paramters (e.g. part_size_bytes) will be passed to TensorPartContainer
     """
@@ -56,9 +55,9 @@ def __init__(
         prefix: Optional[str],
         group_id: GroupID,
         tensors: Sequence[torch.Tensor],
+        weight: Optional[float] = None,
         ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
-        weights: Optional[Sequence[float]] = None,
         modes: Optional[Sequence[AveragingMode]] = None,
         gathered: Optional[Dict[PeerID, Any]] = None,
         **kwargs,
@@ -73,23 +72,24 @@ def __init__(
         self._prefix = prefix

         modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
-        weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
-        assert len(weights) == len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
+        assert len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
         assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
-        for mode, frac, weight in zip(modes, peer_fractions, weights):
+        for mode, frac in zip(modes, peer_fractions):
             assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
-            assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"

         self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids
         self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered

+        if weight is None:
+            weight = float(modes[self.ordered_peer_ids.index(self.peer_id)] != AveragingMode.AUX)
+        self.weight = weight
+
         self._future = asyncio.Future()

-        self.sender_peer_ids, self.sender_weights = [], []
-        for peer_id, weight, mode in zip(self.ordered_peer_ids, weights, modes):
+        self.sender_peer_ids = []
+        for peer_id, mode in zip(self.ordered_peer_ids, modes):
             if mode != AveragingMode.AUX:
                 self.sender_peer_ids.append(peer_id)
-                self.sender_weights.append(weight)

         peer_id_index = self.ordered_peer_ids.index(self.peer_id)
         self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
@@ -97,7 +97,6 @@ def __init__(
         self.tensor_part_reducer = TensorPartReducer(
             tuple(part.shape for part in self.parts_for_local_averaging),
             len(self.sender_peer_ids),
-            self.sender_weights,
         )

     def __repr__(self):
@@ -149,7 +148,9 @@ async def _communicate_with_peer(self, peer_id: PeerID):
         if peer_id == self.peer_id:
             sender_index = self.sender_peer_ids.index(peer_id)
             for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
-                averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+                averaged_part = await self.tensor_part_reducer.accumulate_part(
+                    sender_index, part_index, tensor_part, weight=self.weight
+                )
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)

         else:
@@ -180,9 +181,10 @@ async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[avera
             code=averaging_pb2.PART_FOR_AVERAGING,
             group_id=self.group_id,
             tensor_part=first_part,
+            weight=self.weight,
         )
         async for part in parts_aiter:
-            yield averaging_pb2.AveragingData(tensor_part=part)
+            yield averaging_pb2.AveragingData(tensor_part=part, weight=self.weight)

     async def rpc_aggregate_part(
         self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
@@ -219,14 +221,16 @@ def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Opti
     async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
         loop = asyncio.get_event_loop()
-        async for part_index, (tensor_part, part_compression) in aenumerate(
+        async for part_index, (tensor_part, weight, part_compression) in aenumerate(
             amap_in_executor(
-                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.tensor_part.compression),
+                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.weight, msg.tensor_part.compression),
                 stream,
                 max_prefetch=self.tensor_part_container.prefetch,
             )
         ):
-            averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+            averaged_part = await self.tensor_part_reducer.accumulate_part(
+                sender_index, part_index, tensor_part, weight=weight
+            )
             serialized_delta = await loop.run_in_executor(
                 None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
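With this change, every streamed `AveragingData` message carries the sender's scalar weight next to its tensor part, so a group no longer needs to agree on a full `weights` list up front. A minimal round-trip sketch of that message (illustrative only; it assumes a hivemind build that already includes the `averaging.proto` change further down in this patch, and leaves `tensor_part` unset to stay self-contained):

```python
# Sketch: the sender attaches its scalar weight to every part it streams, and the
# receiver reads msg.weight next to msg.tensor_part, as _generate_input_for_peer and
# _accumulate_parts_streaming do above. tensor_part is deliberately left empty here.
from hivemind.proto import averaging_pb2

msg = averaging_pb2.AveragingData(group_id=b"demo-group", weight=2.0)
assert msg.weight == 2.0  # receiver side: forwarded as accumulate_part(..., weight=msg.weight)
```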
diff --git a/hivemind/averaging/averager.py b/hivemind/averaging/averager.py
index 108ee3f7a..fc8018f30 100644
--- a/hivemind/averaging/averager.py
+++ b/hivemind/averaging/averager.py
@@ -367,7 +367,7 @@ async def _step(
         while not future.done():
             try:
                 self._pending_group_assembled.clear()
-                data_for_gather = self.serializer.dumps([weight, self.bandwidth, self.mode.value, gather_binary])
+                data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, gather_binary])
                 group_info = await self._matchmaking.look_for_group(
                     timeout=timeout, data_for_gather=data_for_gather
                 )
@@ -376,7 +376,9 @@

                 future.set_result(
                     await asyncio.wait_for(
-                        self._run_allreduce(group_info, tensor_infos=self.tensor_infos, **self.allreduce_kwargs),
+                        self._run_allreduce(
+                            group_info, tensor_infos=self.tensor_infos, weight=weight, **self.allreduce_kwargs
+                        ),
                         timeout=self._allreduce_timeout,
                     )
                 )
@@ -414,7 +416,7 @@ async def _step(
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
-            weights, bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
+            bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered)))
             modes = tuple(map(AveragingMode, mode_ids))

@@ -435,7 +437,6 @@ async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kw
                 tensors=local_tensors,
                 ordered_peer_ids=group_info.peer_ids,
                 peer_fractions=peer_fractions,
-                weights=weights,
                 gathered=user_gathered,
                 modes=modes,
                 **kwargs,
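Since the weight now travels with the all-reduce parts themselves, it is dropped from the metadata that peers exchange during matchmaking: `data_for_gather` shrinks to `[bandwidth, mode, gather_binary]`. A small sketch of that tuple (assuming `MSGPackSerializer`, which the averager uses as its `self.serializer`; the concrete values are made up):

```python
# Sketch of the gathered metadata after this patch: no weight entry anymore.
from hivemind.utils.serializer import MSGPackSerializer

bandwidth, mode_value, gather_binary = 1.0e8, 0, b""  # made-up example values
data_for_gather = MSGPackSerializer.dumps([bandwidth, mode_value, gather_binary])
assert MSGPackSerializer.loads(data_for_gather) == [bandwidth, mode_value, gather_binary]
```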
diff --git a/hivemind/averaging/partition.py b/hivemind/averaging/partition.py
index be676e635..dc083b76f 100644
--- a/hivemind/averaging/partition.py
+++ b/hivemind/averaging/partition.py
@@ -167,15 +167,11 @@ class TensorPartReducer:
     Auxiliary data structure responsible for running asynchronous all-reduce
     :param part_shapes: a sequence of shapes of torch tensors that will be averaged by this reducer
     :param num_senders: total number of peers in a given all-reduce group that will send gradients
-    :param weights: relative importance of each sender, used for weighted average (default = equal weights)
     :note: even if local peer is not sending data, local parts will be used for shape information
     """

-    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int, weights: Optional[Sequence[float]] = None):
+    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int):
         self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
-        self.weights = tuple(weights or (1 for _ in range(num_senders)))
-        assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
-        assert all(isinstance(weight, (int, float)) for weight in self.weights)
         self.current_part_index = -1  # index in local_parts of the part that should be loaded next
         self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
         self.accumulator = None  # this will contain the sum of current tensor part from group peers
@@ -197,7 +193,9 @@ def reset_accumulators(self):
         self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
         self.denominator = 0.0

-    async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
+    async def accumulate_part(
+        self, sender_index: int, part_index: int, tensor_part: torch.Tensor, weight: float = 1.0
+    ) -> torch.Tensor:
         """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
         assert 0 <= sender_index < self.num_senders, "invalid sender index"
         assert 0 <= part_index < self.num_parts, "invalid part index"
@@ -211,9 +209,9 @@ async def accumulate_part(self, sender_index: int, part_index: int, tensor_part:

         current_part_future = self.current_part_future

-        self.accumulator.add_(tensor_part, alpha=self.weights[sender_index])
-        self.denominator += self.weights[sender_index]
+        self.accumulator.add_(tensor_part, alpha=weight)
         self.current_part_accumulated_from += 1
+        self.denominator += weight

         assert self.current_part_accumulated_from <= self.num_senders
         if self.current_part_accumulated_from == self.num_senders:
diff --git a/hivemind/proto/averaging.proto b/hivemind/proto/averaging.proto
index 666da3cdc..064fa80af 100644
--- a/hivemind/proto/averaging.proto
+++ b/hivemind/proto/averaging.proto
@@ -45,7 +45,7 @@ message AveragingData {
   bytes group_id = 2;      // a unique group identifier, same as in MessageFromLeader
   bytes peer_id = 3;       // sender's rpc peer_id, used for coordination
   Tensor tensor_part = 4;  // either peer's local tensor part (rpc input) or group average of this part (rpc output)
-  bytes metadata = 5;      // reserved user-extendable metadata
+  double weight = 5;       // peers will be averaged in proportion to these weights
 }

 message DownloadRequest {}
diff --git a/tests/test_allreduce.py b/tests/test_allreduce.py
index 66d82678f..391412ceb 100644
--- a/tests/test_allreduce.py
+++ b/tests/test_allreduce.py
@@ -187,7 +187,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
     group_id = random.getrandbits(160).to_bytes(length=20, byteorder="big")

     allreduce_protocols = []
-    for p2p in p2ps:
+    for i, p2p in enumerate(p2ps):
         allreduce_protocol = AllReduceRunner(
             p2p=p2p,
             servicer_type=AllReduceRunner,
@@ -197,7 +197,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
             ordered_peer_ids=peers,
             peer_fractions=peer_fractions,
             modes=peer_modes,
-            weights=averaging_weights,
+            weight=averaging_weights[i],
             part_size_bytes=part_size_bytes,
         )
         await allreduce_protocol.add_p2p_handlers(p2p)
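At the user level the entry point stays the same: each peer passes its own scalar weight to `step()`, which `_step` now forwards to `_run_allreduce` instead of packing it into `data_for_gather`. A hedged caller-side sketch (the constructor arguments and `step()` signature below reflect `DecentralizedAverager` around the time of this change and are assumptions rather than part of the patch; averaging only completes once enough groupmates call `step()` concurrently):

```python
import torch
import hivemind

dht = hivemind.DHT(start=True)
averager = hivemind.averaging.DecentralizedAverager(
    averaged_tensors=[torch.zeros(3)],  # tensors to be averaged with groupmates
    dht=dht,
    prefix="demo_run",        # namespace shared by the whole averaging group
    target_group_size=2,
    start=True,
)

batch_size = 32  # e.g. weight each peer by how many samples it processed locally
averager.step(weight=float(batch_size), timeout=60)  # groupmates must also call step()
```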