Improve All-Reduce fault-tolerance #423
Conversation
Codecov Report
@@ Coverage Diff @@
## master #423 +/- ##
==========================================
+ Coverage 83.40% 84.22% +0.81%
==========================================
Files 77 77
Lines 7809 7891 +82
==========================================
+ Hits 6513 6646 +133
+ Misses 1296 1245 -51
Thanks for the PR! I've gone through everything but the tests and left my comments.
hivemind/utils/asyncio.py
        await queue.put(loop.run_in_executor(executor, func, *args))
    await queue.put(None)
except BaseException as e:
    await queue.put(e)  # note: there is no chance that iterables
Unfinished comment
fixed
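For context, the pattern in the snippet above can be sketched self-contained: a producer coroutine schedules a blocking `func` on an executor, forwards each resulting future through an `asyncio.Queue`, relays any exception to the consumer, and uses `None` as an end-of-stream sentinel. The `produce`/`consume` names and the consumer loop are illustrative assumptions, not hivemind's API.

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

async def produce(queue: asyncio.Queue, executor, func, items):
    """Run func(item) in the executor; forward futures, then a None sentinel.
    Any exception is put on the queue so the consumer can re-raise it."""
    loop = asyncio.get_running_loop()
    try:
        for item in items:
            await queue.put(loop.run_in_executor(executor, func, item))
        await queue.put(None)  # sentinel: no more items
    except BaseException as e:
        await queue.put(e)  # relay the failure to the consumer side

async def consume(queue: asyncio.Queue):
    """Drain the queue until the sentinel; re-raise relayed exceptions."""
    results = []
    while True:
        entry = await queue.get()
        if entry is None:
            return results
        if isinstance(entry, BaseException):
            raise entry
        results.append(await entry)  # executor futures are awaitable

async def main():
    queue = asyncio.Queue()
    with ThreadPoolExecutor() as executor:
        producer = asyncio.create_task(produce(queue, executor, lambda x: x * x, [1, 2, 3]))
        results = await consume(queue)
        await producer
    return results

print(asyncio.run(main()))  # prints [1, 4, 9]
```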
hivemind/utils/asyncio.py
task.cancel()
if task.done() and not task.cancelled():
    task.exception()
Did you mean `raise task.exception()`? If yes, it does not seem necessary since `await task` will already raise this exception.
I specifically mean: "if the task did not send a result or throw an exception but we died anyway, silence the `Task ... was never retrieved` warning".
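The author's point can be demonstrated in isolation: if a background task already failed on its own and nobody ever awaits it, asyncio logs "Task exception was never retrieved" when the task is garbage-collected; calling `task.exception()` on a done, non-cancelled task marks the exception as retrieved. The `flaky`/`finalize` names below are illustrative, not from the PR.

```python
import asyncio

async def flaky():
    raise ValueError("worker failed")

async def finalize(task: asyncio.Task):
    """Mirror of the reviewed snippet: cancel, then consume the exception of a
    task that had already failed, so asyncio never logs
    'Task exception was never retrieved' at garbage collection."""
    task.cancel()  # no-op if the task is already done
    if task.done() and not task.cancelled():
        task.exception()  # mark the exception as retrieved; value is ignored

async def main():
    task = asyncio.create_task(flaky())
    await asyncio.sleep(0)  # let the task run and fail
    await finalize(task)
    return isinstance(task.exception(), ValueError)

print(asyncio.run(main()))  # prints True
```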
hivemind/averaging/partition.py
self.current_senders -= 1
if self.current_part_accumulated_from == self.current_senders:
    self.current_part_future.set_result(self.accumulator.div_(self.denominator))
    self.reset_accumulators()
Please extract these two lines into a function; this is extremely error-prone code duplication.
For instance, imagine that someone writes a CenteredClip reducer and only changes L229-230 (not L239-240): since the latter are executed only during a failure, the unit tests are likely to miss that mistake.
done
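The suggested refactoring can be sketched with a toy reducer (all names here are illustrative, modeled loosely on `partition.py`, not hivemind's actual class): both the normal accumulation path and the sender-failure path finish a part through one shared helper, so a future change cannot update one copy and miss the other.

```python
class PartReducer:
    """Toy model: a part is finalized in exactly one place."""

    def __init__(self, num_senders: int):
        self.current_senders = num_senders
        self.accumulated_from = 0
        self.accumulator = 0.0
        self.denominator = 0.0
        self.result = None

    def _check_reduce_finished(self):
        # the single finalization point: if every remaining sender has
        # contributed, publish the average and reset the accumulators
        if self.accumulated_from == self.current_senders:
            self.result = self.accumulator / self.denominator
            self._reset_accumulators()

    def _reset_accumulators(self):
        self.accumulator = self.denominator = 0.0
        self.accumulated_from = 0

    def accumulate(self, value: float, weight: float = 1.0):
        self.accumulator += value * weight
        self.denominator += weight
        self.accumulated_from += 1
        self._check_reduce_finished()

    def on_sender_failed(self):
        self.current_senders -= 1
        self._check_reduce_finished()

reducer = PartReducer(num_senders=3)
reducer.accumulate(2.0)
reducer.accumulate(4.0)
reducer.on_sender_failed()  # third sender dies: part finishes with 2 contributions
print(reducer.result)  # prints 3.0
```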
hivemind/averaging/allreduce.py
@@ -132,6 +148,8 @@ def should_delay_results(self, peer_id: PeerID) -> bool:
     async def run(self) -> AsyncIterator[torch.Tensor]:
         """Run all-reduce, return differences between averaged and original tensors as they are computed"""
         pending_tasks = set()
+        if self.sender_timeout is not None:
+            asyncio.create_task(self._handle_missing_senders())
Such fire-and-forget calls always lead to "Task was destroyed but it is pending!" warnings. Please fix that, e.g., save the task to a class field and `await` it when finalizing.
GREAT catch: I meant to add it to pending_tasks (and I just did so).
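The resolution can be sketched in miniature (the class and method names below are illustrative stand-ins, not hivemind's actual implementation): the background watchdog task is registered in `pending_tasks` instead of being dropped, and finalization cancels and awaits everything in that set, so no task is ever destroyed while pending.

```python
import asyncio

class AllReduceLike:
    """Illustrative sketch: background tasks are tracked, never fire-and-forget."""

    def __init__(self):
        self.pending_tasks = set()
        self.events = []

    async def _handle_missing_senders(self):
        await asyncio.sleep(0.01)  # stand-in for waiting out sender_timeout
        self.events.append("checked senders")

    async def run(self):
        # register the watchdog instead of calling create_task and dropping it
        self.pending_tasks.add(asyncio.create_task(self._handle_missing_senders()))
        self.events.append("ran all-reduce")
        await self._finalize()

    async def _finalize(self):
        for task in self.pending_tasks:
            task.cancel()
        # awaiting the (possibly cancelled) tasks retrieves their outcomes,
        # which avoids both "destroyed but pending" and "never retrieved" warnings
        await asyncio.gather(*self.pending_tasks, return_exceptions=True)

async def main():
    obj = AllReduceLike()
    await obj.run()
    return obj.events

print(asyncio.run(main()))  # prints ['ran all-reduce']
```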
hivemind/averaging/allreduce.py
Outdated
@@ -205,28 +262,32 @@ async def rpc_aggregate_part(
     ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """a peer sends us a part of his tensor; we should average it with other peers and return the difference"""
         request: averaging_pb2.AveragingData = await anext(stream)
Please handle `self.sender_timeout` for this first part as well.
done
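One way to bound the first message of a stream by a timeout, as the reviewer requests, is `asyncio.wait_for`. The helper and stream names below are illustrative assumptions, not hivemind's code; `stream.__anext__()` is the portable spelling of `anext(stream)` for Python < 3.10.

```python
import asyncio

async def first_part_with_timeout(stream, sender_timeout):
    """Illustrative sketch: the very first message is also bounded by sender_timeout."""
    if sender_timeout is not None:
        return await asyncio.wait_for(stream.__anext__(), timeout=sender_timeout)
    return await stream.__anext__()

async def fast_stream():
    yield "part-0"

async def slow_stream():
    await asyncio.sleep(0.2)  # simulates a sender that never delivers in time
    yield "part-0"

async def main():
    first = await first_part_with_timeout(fast_stream(), sender_timeout=0.05)
    try:
        await first_part_with_timeout(slow_stream(), sender_timeout=0.05)
        timed_out = False
    except asyncio.TimeoutError:
        timed_out = True
    return first, timed_out

print(asyncio.run(main()))  # prints ('part-0', True)
```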
hivemind/averaging/allreduce.py
async with self.banlock:
    if context.remote_id not in self.banned_senders:
        self.banned_senders.add(context.remote_id)
        self.tensor_part_reducer.on_sender_failed(sender_index)
Please fix this code duplication (L305-308, L314-317, L352-355), this is likely to lead to a bug if someone decides to change the ban procedure (but changes only 1-2 of 3 copies).
Replaced with a unified `await self._ban_sender(peer_id)` call.
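The unified helper the author describes can be sketched like this (class and attribute names are illustrative, modeled on the snippet above; the real `_ban_sender` in the PR may differ): every failure path calls one idempotent method under the lock instead of repeating the check/ban sequence.

```python
import asyncio

class BanTracker:
    """Illustrative sketch: one idempotent ban path instead of three copies."""

    def __init__(self):
        self.banlock = asyncio.Lock()
        self.banned_senders = set()
        self.failed_indices = []  # stand-in for reducer.on_sender_failed calls

    async def _ban_sender(self, peer_id, sender_index):
        async with self.banlock:
            if peer_id not in self.banned_senders:
                self.banned_senders.add(peer_id)
                self.failed_indices.append(sender_index)

async def main():
    tracker = BanTracker()
    await tracker._ban_sender("peer-A", 0)
    await tracker._ban_sender("peer-A", 0)  # second ban of the same peer is a no-op
    await tracker._ban_sender("peer-B", 1)
    return sorted(tracker.banned_senders), tracker.failed_indices

print(asyncio.run(main()))  # prints (['peer-A', 'peer-B'], [0, 1])
```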
Co-authored-by: Alexander Borzunov <[email protected]>
…ault_tolerant_allreduce
Test cases:
Sanity checks: