Improve All-Reduce fault-tolerance #423

Merged
merged 61 commits on Dec 14, 2021
Changes from 38 commits
Commits (61)
eced497
proper scheduling for gradient averaging
justheuristic Dec 11, 2021
9e5561d
more aggressive pre-scheduling
justheuristic Dec 11, 2021
82c9473
fault-tolerant all-reduce with next chunk timeout
justheuristic Dec 12, 2021
b1b3336
fault-tolerant all-reduce with next chunk timeout
justheuristic Dec 12, 2021
4991e71
DRY
justheuristic Dec 12, 2021
965afa8
fault-tolerant allreduce: global test 1
justheuristic Dec 12, 2021
874338e
revert peer ids
justheuristic Dec 12, 2021
6838fe1
blackisort
justheuristic Dec 12, 2021
1b0e1bd
cancel
justheuristic Dec 12, 2021
ceb2a0a
next_chunk_timeout in optimizer
justheuristic Dec 13, 2021
94972a7
black
justheuristic Dec 13, 2021
eaa1603
Update hivemind/averaging/allreduce.py
justheuristic Dec 13, 2021
6f1e22a
Update hivemind/averaging/partition.py
justheuristic Dec 13, 2021
35d8306
Update hivemind/averaging/averager.py
justheuristic Dec 13, 2021
298f437
track pending tasks
justheuristic Dec 13, 2021
df67842
review
borzunov Dec 13, 2021
a3897a8
review
borzunov Dec 13, 2021
e51ee75
review
borzunov Dec 13, 2021
2386611
review
borzunov Dec 13, 2021
8eac4a0
review
borzunov Dec 13, 2021
cc348a8
review
borzunov Dec 13, 2021
ac2c004
black it
justheuristic Dec 13, 2021
abc67e4
Update hivemind/averaging/allreduce.py
justheuristic Dec 13, 2021
92ab91e
review
borzunov Dec 13, 2021
be13c1f
Merge remote-tracking branch 'origin/fault_tolerant_allreduce' into f…
justheuristic Dec 13, 2021
aa0f2b1
review
borzunov Dec 13, 2021
539a573
Update hivemind/averaging/partition.py
justheuristic Dec 13, 2021
dfa0d1a
reblack
justheuristic Dec 13, 2021
c4f4c1e
Merge remote-tracking branch 'origin/fault_tolerant_allreduce' into f…
justheuristic Dec 13, 2021
ddcb240
review
borzunov Dec 13, 2021
4df6ea7
remove if
justheuristic Dec 13, 2021
4ce62f6
review
borzunov Dec 13, 2021
a42d20c
review
justheuristic Dec 13, 2021
53071e9
review
borzunov Dec 13, 2021
c5be676
Update hivemind/averaging/allreduce.py
justheuristic Dec 13, 2021
e3506d8
review
borzunov Dec 13, 2021
8d2427e
review
borzunov Dec 13, 2021
5adfec7
review
justheuristic Dec 13, 2021
3bb5e03
review
borzunov Dec 13, 2021
e50e459
add test case
justheuristic Dec 13, 2021
f564176
isort
justheuristic Dec 13, 2021
6e86743
test for slowness
borzunov Dec 14, 2021
12b708c
report integrity
justheuristic Dec 14, 2021
cb71c7f
report integrity
borzunov Dec 14, 2021
1d56876
nirvana tests
justheuristic Dec 14, 2021
c09ed95
nirvana tests
justheuristic Dec 14, 2021
aadf8c3
nirvana tests
justheuristic Dec 14, 2021
ddeb508
nirvana tests
justheuristic Dec 14, 2021
2b0d850
nirvana tests
justheuristic Dec 14, 2021
671ca20
nirvana tests
justheuristic Dec 14, 2021
8862a85
nirvana tests
justheuristic Dec 14, 2021
f61ba96
nirvana tests
justheuristic Dec 14, 2021
5e2f33f
nirvana test - wait for client-mode peers to begin sending data (with…
justheuristic Dec 14, 2021
53ffcca
nirvana test - wait for client-mode peers to begin sending data (with…
justheuristic Dec 14, 2021
65ae51e
nirvana test - wait for client-mode peers to begin sending data (with…
justheuristic Dec 14, 2021
91e8c3c
nirvana test - wait for client-mode peers to begin sending data (with…
justheuristic Dec 14, 2021
10121f8
nirvana test - wait for client-mode peers to begin sending data (with…
justheuristic Dec 14, 2021
3d8fb18
fix edge case where all peers are auxiliaries
justheuristic Dec 14, 2021
4b675aa
blacken
justheuristic Dec 14, 2021
b03ccd9
better logging
justheuristic Dec 14, 2021
a803aa8
learned from 160 peers
justheuristic Dec 14, 2021
197 changes: 119 additions & 78 deletions hivemind/averaging/allreduce.py
@@ -1,6 +1,6 @@
import asyncio
from enum import Enum
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Set, Tuple, Type

import torch

@@ -11,8 +11,7 @@
from hivemind.utils import get_logger
from hivemind.utils.asyncio import (
achain,
aenumerate,
afirst,
aiter_with_timeout,
amap_in_executor,
anext,
as_aiter,
@@ -52,6 +51,10 @@ class AllReduceRunner(ServicerBase):
(the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
:param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
:param gathered: additional user-defined data collected from this group
:param sender_timeout: during all_reduce, any sender that fails to send a tensor chunk within this many seconds
of the previous chunk will be marked as failed and excluded from averaging. default: equal to next_chunk_timeout
:param reducer_timeout: during all_reduce, any reducer that fails to send a results chunk within this many seconds
of the previous chunk will be marked as failed and excluded from averaging. default: 2 x sender_timeout
:param kwargs: additional parameters (e.g. part_size_bytes) will be passed to TensorPartContainer
:note: Full-mode peers send and receive tensor parts concurrently, assuming a full-duplex TCP stream. In turn,
non-averaging peers receive results only after they finish sending, which helps them avoid
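
Note: a minimal sketch of how these two timeouts are meant to relate. The values below are illustrative, not part of the hivemind API; the actual constraint (sender_timeout strictly shorter than reducer_timeout) is enforced by the constructor in the next hunk.

# Illustrative timeout configuration. A reducer can only emit its next
# output chunk after receiving the matching inputs, so it needs strictly
# more slack than the senders it waits on; otherwise a reducer could be
# banned merely because one of its own senders stalled.
sender_timeout = 30.0                    # max seconds between a sender's chunks
reducer_timeout = 2 * sender_timeout     # mirrors the documented default: 2 x sender_timeout
assert sender_timeout < reducer_timeout  # the check AllReduceRunner.__init__ performs
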
@@ -71,11 +74,18 @@ def __init__(
peer_fractions: Tuple[float, ...],
modes: Optional[Sequence[AveragingMode]] = None,
gathered: Optional[Dict[PeerID, Any]] = None,
sender_timeout: Optional[float] = None,
reducer_timeout: Optional[float] = None,
**kwargs,
):
self._p2p = p2p
self.peer_id = p2p.peer_id
assert self.peer_id in ordered_peer_ids, "peer_id is not a part of the group"
if reducer_timeout is not None and (sender_timeout is None or reducer_timeout <= sender_timeout):
raise ValueError(
"If reducer_timeout is enabled, sender_timeout must be shorter than reducer_timeout. "
"Otherwise, there is a chance that reducers will be banned while they await senders."
)

if not issubclass(servicer_type, ServicerBase):
raise TypeError("`servicer_type` is expected to be a ServicerBase subclass")
@@ -102,6 +112,11 @@ def __init__(
if mode != AveragingMode.AUX:
self.sender_peer_ids.append(peer_id)

self.sender_timeout, self.reducer_timeout = sender_timeout, reducer_timeout
self.active_senders: Set[PeerID] = {self.peer_id} # peers that began sending data via rpc_aggregate_part
self.banned_senders: Set[PeerID] = set() # peers that did not send data by next_chunk_timeout
self.banlock = asyncio.Lock()

peer_id_index = self.ordered_peer_ids.index(self.peer_id)
self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(peer_id_index)
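
Note: the three attributes above (active_senders, banned_senders, banlock) drive the fault-tolerance bookkeeping. Below is a self-contained sketch of the idempotent ban they enable, mirroring _ban_sender further down in this diff; peer ids and the print-based stand-in are illustrative.

import asyncio
from typing import Set

async def main() -> None:
    banned_senders: Set[str] = set()
    banlock = asyncio.Lock()

    async def ban_sender(peer_id: str) -> None:
        # the lock ensures concurrent failure reports for one peer take effect once
        async with banlock:
            if peer_id not in banned_senders:
                banned_senders.add(peer_id)
                print(f"excluding {peer_id} from this reduction")

    # two concurrent failure reports for the same peer: exactly one ban
    await asyncio.gather(ban_sender("peer-1"), ban_sender("peer-1"))

asyncio.run(main())
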
@@ -132,6 +147,8 @@ def should_delay_results(self, peer_id: PeerID) -> bool:
async def run(self) -> AsyncIterator[torch.Tensor]:
"""Run all-reduce, return differences between averaged and original tensors as they are computed"""
pending_tasks = set()
if self.sender_timeout is not None:
pending_tasks.add(asyncio.create_task(self._handle_missing_senders()))
try:
if len(self.sender_peer_ids) == 0:
logger.debug(f"{self} - finished all-reduce early: all peers are auxiliaries ({self.modes})")
@@ -151,11 +168,21 @@ async def run(self) -> AsyncIterator[torch.Tensor]:
self.finalize()

except BaseException as e:
self.finalize(exception=e)
for task in pending_tasks:
task.cancel()
if task.done() and not task.cancelled():
logger.debug(f"Task {task} failed with {task.exception()}", exc_info=True)
self.finalize(exception=e)
raise

async def _handle_missing_senders(self):
"""Detect senders that should have sent tensors for averaging, but did not send anything within timeout"""
assert self.sender_timeout is not None
await asyncio.sleep(self.sender_timeout)
for peer_id in self.sender_peer_ids:
if peer_id not in self.active_senders and peer_id not in self.banned_senders:
await self._ban_sender(peer_id)

async def _communicate_with_peer(self, peer_id: PeerID):
"""Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors"""
peer_index = self.ordered_peer_ids.index(peer_id)
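
Note: to make the watchdog in _handle_missing_senders easier to follow, here is a self-contained sketch of the same pattern. Peer names, the 0.1 s timeout, and the print-based ban are illustrative stand-ins for the real _ban_sender call.

import asyncio
from typing import Set

async def handle_missing_senders(
    expected: Set[str], active: Set[str], banned: Set[str], sender_timeout: float
) -> None:
    # wait out the grace period, then fail every sender that never showed up
    await asyncio.sleep(sender_timeout)
    for peer_id in expected:
        if peer_id not in active and peer_id not in banned:
            banned.add(peer_id)  # the real code awaits _ban_sender(peer_id) here
            print(f"{peer_id} sent nothing within {sender_timeout}s, excluded")

async def main() -> None:
    active = {"peer-A"}  # simulate: only peer-A ever began sending data
    await handle_missing_senders({"peer-A", "peer-B"}, active, set(), sender_timeout=0.1)

asyncio.run(main())
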
@@ -168,25 +195,38 @@ async def _communicate_with_peer(self, peer_id: PeerID):
self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)

else:
code = None
stream = await self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
async for part_index, (averaged_part_delta, msg) in aenumerate(
amap_in_executor(
lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg),
stream,
try:
done_sending = asyncio.Event()
inputs_aiter = attach_event_on_finished(self._generate_input_for_peer(peer_index), done_sending)
stream = await self._get_peer_stub(peer_id).rpc_aggregate_part(inputs_aiter)

if self.should_delay_results(self.peer_id):
await done_sending.wait()

part_index = 0

def _try_deserialize(msg):
if msg.code != averaging_pb2.AVERAGED_PART:
raise AllreduceException(f"{peer_id} sent {averaging_pb2.MessageCode.Name(msg.code)}")
return deserialize_torch_tensor(msg.tensor_part), msg

async for delta, msg in amap_in_executor(
_try_deserialize,
aiter_with_timeout(stream, self.reducer_timeout),
max_prefetch=self.tensor_part_container.prefetch,
)
):
if code is None:
code = msg.code
self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)

if code != averaging_pb2.AVERAGED_PART:
raise AllreduceException(
f"peer {peer_id} returned {averaging_pb2.MessageCode.Name(code)} "
f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
f", allreduce failed"
)
):
self.tensor_part_container.register_processed_part(peer_index, part_index, delta)
part_index += 1

if part_index != self.tensor_part_container.num_parts_by_peer[peer_index]:
raise AllreduceException(
f"peer {peer_id} sent {part_index} parts, but we expected "
f"{self.tensor_part_container.num_parts_by_peer[peer_index]}"
)
except BaseException as e:
logger.warning(f"Caught {repr(e)} when communicating to {peer_id}")
self.tensor_part_container.register_failed_reducer(peer_index)
raise

async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
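
Note: the per-chunk deadlines above hinge on aiter_with_timeout. hivemind ships its own version in hivemind.utils.asyncio; the sketch below is a rough equivalent for illustration only, plus a demo of a stalling stream with illustrative timings. Each next item must arrive within `timeout` seconds of the previous one, or the whole stream fails with asyncio.TimeoutError.

import asyncio
from typing import AsyncIterator, Optional, TypeVar

T = TypeVar("T")

async def aiter_with_timeout(iterable: AsyncIterator[T], timeout: Optional[float]) -> AsyncIterator[T]:
    iterator = iterable.__aiter__()
    while True:
        try:
            # each chunk gets a fresh deadline, measured from the previous chunk
            yield await asyncio.wait_for(iterator.__anext__(), timeout)
        except StopAsyncIteration:
            break

async def stalling_stream() -> AsyncIterator[int]:
    yield 1
    await asyncio.sleep(10)  # the peer stalls here
    yield 2

async def main() -> None:
    try:
        async for item in aiter_with_timeout(stalling_stream(), timeout=0.1):
            print("got chunk", item)
    except asyncio.TimeoutError:
        print("peer stalled between chunks; it would be excluded from averaging")

asyncio.run(main())
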
@@ -204,29 +244,34 @@ async def rpc_aggregate_part(
self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
) -> AsyncIterator[averaging_pb2.AveragingData]:
"""a peer sends us a part of his tensor; we should average it with other peers and return the difference"""
request: averaging_pb2.AveragingData = await anext(stream)
reason_to_reject = self._check_reasons_to_reject(request)
if reason_to_reject:
yield reason_to_reject
return

elif request.code == averaging_pb2.PART_FOR_AVERAGING:
try:
sender_index = self.sender_peer_ids.index(context.remote_id)
sender_index = self.sender_peer_ids.index(context.remote_id)

try:
request: averaging_pb2.AveragingData = await asyncio.wait_for(anext(stream), self.sender_timeout)
reason_to_reject = self._check_reasons_to_reject(request, context)
if reason_to_reject:
yield reason_to_reject
return

elif request.code == averaging_pb2.PART_FOR_AVERAGING:
stream = aiter_with_timeout(achain(as_aiter(request), stream), self.sender_timeout)
self.active_senders.add(context.remote_id)
if not self.should_delay_results(context.remote_id):
async for msg in self._accumulate_parts_streaming(achain(as_aiter(request), stream), sender_index):
async for msg in self._accumulate_parts_streaming(stream, sender_index):
yield msg

else:
done_receiving = asyncio.Event()
delayed_results = asyncio.Queue()

async def _accumulate_parts():
inputs_aiter = attach_event_on_finished(achain(as_aiter(request), stream), done_receiving)
async for msg in self._accumulate_parts_streaming(inputs_aiter, sender_index):
delayed_results.put_nowait(msg)
delayed_results.put_nowait(None)
try:
async for msg in self._accumulate_parts_streaming(
attach_event_on_finished(stream, done_receiving), sender_index
):
delayed_results.put_nowait(msg)
finally:
delayed_results.put_nowait(None)

accumulate_task = asyncio.create_task(_accumulate_parts())
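
Note: the delayed-results branch above buffers averaged chunks for peers that should receive results only after they finish sending (see should_delay_results). A minimal sketch of the queue-plus-sentinel mechanics; names are illustrative, and the None sentinel is enqueued in a finally block so the consumer terminates even if the producer fails midway.

import asyncio
from typing import Optional

async def accumulate(queue: "asyncio.Queue[Optional[int]]") -> None:
    try:
        for part in range(3):
            await asyncio.sleep(0.01)  # stand-in for tensor_part_reducer.accumulate_part
            queue.put_nowait(part)
    finally:
        queue.put_nowait(None)  # sentinel: no more results will arrive

async def main() -> None:
    queue: "asyncio.Queue[Optional[int]]" = asyncio.Queue()
    task = asyncio.create_task(accumulate(queue))
    # ... the real code first waits for done_receiving before draining ...
    while (result := await queue.get()) is not None:
        print("sending delayed result:", result)
    await task

asyncio.run(main())
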

@@ -239,63 +284,61 @@ async def _accumulate_parts():
yield next_result
await accumulate_task

except Exception as e:
self.finalize(exception=e)
else:
yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
else:
error_code = averaging_pb2.MessageCode.Name(request.code)
logger.debug(f"{self} - peer {context.remote_id} sent {error_code}, allreduce cannot continue")
self.finalize(exception=AllreduceException(f"Peer {context.remote_id} sent {error_code}"))
yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
raise AllreduceException(f"{context.remote_id} sent {averaging_pb2.MessageCode.Name(request.code)}")

def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]:
except BaseException as e:
await self._ban_sender(context.remote_id)
if isinstance(e, Exception):
logger.warning(f"Caught {repr(e)} when communicating with {context.remote_id}")
yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
else:
raise # CancelledError, StopIteration and similar

async def _ban_sender(self, peer_id: PeerID):
async with self.banlock:
if peer_id not in self.banned_senders:
self.banned_senders.add(peer_id)
self.tensor_part_reducer.on_sender_failed(self.sender_peer_ids.index(peer_id))

def _check_reasons_to_reject(
self, request: averaging_pb2.AveragingData, context: P2PContext
) -> Optional[averaging_pb2.AveragingData]:
if request.group_id != self.group_id:
return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
elif self._future.cancelled():
return averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
elif self._future.done():
return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
elif context.remote_id not in self.sender_peer_ids:
return averaging_pb2.AveragingData(code=averaging_pb2.PROTOCOL_VIOLATION)

async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
loop = asyncio.get_event_loop()
async for part_index, (tensor_part, weight, part_compression) in aenumerate(
amap_in_executor(
part_index = 0
try:
loop = asyncio.get_event_loop()
async for tensor_part, weight, part_compression in amap_in_executor(
lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.weight, msg.tensor_part.compression),
stream,
max_prefetch=self.tensor_part_container.prefetch,
)
):
averaged_part = await self.tensor_part_reducer.accumulate_part(
sender_index, part_index, tensor_part, weight=weight
)

serialized_delta = await loop.run_in_executor(
None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
)
yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
):
averaged_part = await self.tensor_part_reducer.accumulate_part(
sender_index, part_index, tensor_part, weight=weight
)
part_index += 1

async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
try:
error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
await afirst(await self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
except Exception as e:
logger.debug(f"Caught {e} when sending error {averaging_pb2.MessageCode.Name(code)} to {peer_id}")
serialized_delta = await loop.run_in_executor(
None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
)
yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
finally:
if part_index != self.tensor_part_reducer.num_parts:
await self._ban_sender(self.sender_peer_ids[sender_index])

def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
"""finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""
assert not cancel or not exception, "finalize accepts either exception or cancel, but not both"
pending_tasks = set()
if cancel or exception:
# propagate error to peers
if cancel or isinstance(exception, asyncio.CancelledError):
code = averaging_pb2.CANCELLED
else:
code = averaging_pb2.INTERNAL_ERROR
logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
for peer_id, mode in zip(self.ordered_peer_ids, self.modes):
if peer_id != self.peer_id and mode != AveragingMode.CLIENT:
pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_id, code)))

if not self._future.done():
if cancel:
logger.debug(f"{self} - cancelled")
@@ -308,7 +351,5 @@ def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] =
self._future.set_result(None)
self.tensor_part_container.finalize()
self.tensor_part_reducer.finalize()
return pending_tasks
else:
logger.debug(f"{self} - could not finish: allreduce is already finished: {self._future}")
return pending_tasks
logger.debug(f"{self} - attempted to finalize allreduce that is already finished: {self._future}")