diff --git a/hivemind/averaging/allreduce.py b/hivemind/averaging/allreduce.py index 3ed9f0779..9bf40a340 100644 --- a/hivemind/averaging/allreduce.py +++ b/hivemind/averaging/allreduce.py @@ -153,13 +153,17 @@ async def _communicate_with_peer(self, peer_id: PeerID): self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part) else: - loop = asyncio.get_event_loop() code = None stream = self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index)) - async for part_index, msg in aenumerate(stream): + async for part_index, (averaged_part_delta, msg) in aenumerate( + amap_in_executor( + lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg), + stream, + max_prefetch=self.tensor_part_container.prefetch, + ) + ): if code is None: code = msg.code - averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part) self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta) if code != averaging_pb2.AVERAGED_PART: diff --git a/hivemind/averaging/partition.py b/hivemind/averaging/partition.py index e1ada3c79..577e2953d 100644 --- a/hivemind/averaging/partition.py +++ b/hivemind/averaging/partition.py @@ -33,7 +33,7 @@ def __init__( peer_fractions: Sequence[float], compression_type: Union["CompressionType", Sequence["CompressionType"]] = CompressionType.NONE, part_size_bytes: int = DEFAULT_PART_SIZE_BYTES, - prefetch: int = 1, + prefetch: int = 5, ): if not isinstance(compression_type, Sequence): compression_type = [compression_type] * len(tensors)