
Improve All-Reduce fault-tolerance #423

Merged · 61 commits · Dec 14, 2021

Conversation

@justheuristic (Member) commented Dec 12, 2021

  • allow AllreduceRunner to tolerate clients that
    • do not send some of their local tensors
    • do not show up at all after matchmaking is over
  • allow AllreduceRunner to tolerate full/aux peers that do not send some or all results
  • introduce timeout after which sender/reducer is considered failed
  • AllreduceRunner & DecentralizedAverager will no longer _send_error_to_peer
    • log spam is gone!
  • report allreduce integrity
    • TensorPartReducer will report the fraction of expected parts received if that fraction is not 1
    • TensorPartContainer will report the fraction of parts that did not fail if that fraction is not 1
  • miscellaneous improvements to Optimizer
    • set good default sender/reducer timeouts
    • pre-schedule state averaging ahead of time
    • no longer block the entire peer if it is time to pre-schedule gradients but background state averaging is still underway
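
A rough sketch of the timeout mechanism described above (class and method names here are illustrative, not the actual hivemind API): a watchdog marks every sender that has not delivered its part within sender_timeout as failed, so the reduction proceeds over the remaining peers instead of blocking forever.

    import asyncio

    class PartReducerSketch:
        def __init__(self, num_senders: int, sender_timeout: float):
            self.num_senders, self.sender_timeout = num_senders, sender_timeout
            self.parts_received, self.failed_senders = set(), set()

        def on_part_received(self, sender_index: int):
            self.parts_received.add(sender_index)

        def on_sender_failed(self, sender_index: int):
            # a failed sender is excluded from the set of expected contributors
            self.failed_senders.add(sender_index)

        async def watchdog(self):
            # after sender_timeout, treat every still-silent sender as failed
            await asyncio.sleep(self.sender_timeout)
            for i in range(self.num_senders):
                if i not in self.parts_received and i not in self.failed_senders:
                    self.on_sender_failed(i)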

Test cases:

  • test with peers that fail early
  • test with peers that fail to send a certain part
  • test with peers that fail to reduce their part
  • test cancelling

Sanity checks:

  • run tests 100 times
  • benchmark_optimizer
  • test env 64+ nodes 4+ hours

codecov bot commented Dec 12, 2021

Codecov Report

Merging #423 (a803aa8) into master (896885a) will increase coverage by 0.81%.
The diff coverage is 93.22%.

@@            Coverage Diff             @@
##           master     #423      +/-   ##
==========================================
+ Coverage   83.40%   84.22%   +0.81%     
==========================================
  Files          77       77              
  Lines        7809     7891      +82     
==========================================
+ Hits         6513     6646     +133     
+ Misses       1296     1245      -51     
Impacted Files Coverage Δ
hivemind/optim/grad_scaler.py 30.98% <0.00%> (+0.84%) ⬆️
hivemind/averaging/averager.py 87.65% <75.00%> (+2.10%) ⬆️
hivemind/averaging/allreduce.py 92.55% <92.47%> (+14.84%) ⬆️
hivemind/averaging/partition.py 98.87% <100.00%> (+0.85%) ⬆️
hivemind/optim/experimental/optimizer.py 62.53% <100.00%> (+0.67%) ⬆️
hivemind/utils/asyncio.py 100.00% <100.00%> (+0.98%) ⬆️
hivemind/optim/experimental/state_averager.py 86.19% <0.00%> (+0.24%) ⬆️
... and 5 more

@borzunov (Member) left a comment

Thanks for the PR! I've gone through everything but the tests and left my comments.

await queue.put(loop.run_in_executor(executor, func, *args))
await queue.put(None)
except BaseException as e:
await queue.put(e) # note: there is no chance that iterables
Member

Unfinished comment

Member Author

fixed

task.cancel()
task.cancel()
if task.done() and not task.cancelled():
task.exception()
Member

Did you mean raise task.exception()? If yes, it does not seem necessary since await task will already raise this exception.

Member Author

I specifically mean: if the task did not send a result or throw an exception but we died anyway, silence the "task ... was never retrieved" warning.

@borzunov (Member) commented Dec 13, 2021

Let's remove L136-137 because .cancel() already suppresses this message.

Indeed, the message is printed only when task.__log_traceback == True (source), and task.cancel() sets it to False (source) exactly as task.exception() does (source).
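
To illustrate the agreed-upon cleanup (a sketch, not the PR's exact code): cancelling each task is enough to silence the warning even for tasks that already failed, and a gather with return_exceptions=True consumes any remaining results or exceptions.

    import asyncio

    async def finalize(pending_tasks: set):
        for task in pending_tasks:
            task.cancel()  # also clears the "log traceback" flag, even if the task is already done
        # retrieve results/exceptions so nothing is reported as "never retrieved"
        await asyncio.gather(*pending_tasks, return_exceptions=True)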

self.current_senders -= 1
if self.current_part_accumulated_from == self.current_senders:
self.current_part_future.set_result(self.accumulator.div_(self.denominator))
self.reset_accumulators()
Member

Please extract these two lines into a function; this is an extremely error-prone code duplication.

For instance, imagine that someone writes a CenteredClip reducer and only changes L229-230 (not L239-240): since the latter are executed only during a failure, the unit tests are likely to miss that mistake.

Member Author

done
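
For reference, the extracted helper might look roughly like this (the method name is assumed; only the two duplicated lines come from the diff above):

    def check_current_part_finished(self):
        # single place that publishes the averaged part once every remaining sender has contributed
        if self.current_part_accumulated_from == self.current_senders:
            self.current_part_future.set_result(self.accumulator.div_(self.denominator))
            self.reset_accumulators()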

@@ -132,6 +148,8 @@ def should_delay_results(self, peer_id: PeerID) -> bool:
async def run(self) -> AsyncIterator[torch.Tensor]:
"""Run all-reduce, return differences between averaged and original tensors as they are computed"""
pending_tasks = set()
if self.sender_timeout is not None:
asyncio.create_task(self._handle_missing_senders())
Member

Such fire-and-forget calls always lead to "Task was destroyed but it is pending". Please fix that, e.g., save the task to a class field and await it when finalizing.

Member Author

Great catch: I meant to add it to pending_tasks (and I just did so).
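
In other words (a sketch of the fix, not verbatim): the watchdog task is kept in pending_tasks, so it is cancelled and awaited together with the other background tasks when the runner finalizes, instead of being garbage-collected while still pending.

    if self.sender_timeout is not None:
        # track the watchdog so finalization can cancel/await it later
        pending_tasks.add(asyncio.create_task(self._handle_missing_senders()))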

@@ -205,28 +262,32 @@ async def rpc_aggregate_part(
) -> AsyncIterator[averaging_pb2.AveragingData]:
"""a peer sends us a part of his tensor; we should average it with other peers and return the difference"""
request: averaging_pb2.AveragingData = await anext(stream)
Member

Please handle self.sender_timeout for this first part as well.

Member Author

done
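
One possible shape of that change (a sketch assuming asyncio.wait_for wraps the first read and that the unified _ban_sender helper mentioned below is used; the merged code may differ):

    try:
        request: averaging_pb2.AveragingData = await asyncio.wait_for(
            anext(stream), timeout=self.sender_timeout
        )
    except asyncio.TimeoutError:
        # a sender that never starts its stream is banned instead of stalling the RPC
        await self._ban_sender(context.remote_id)
        return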

async with self.banlock:
if context.remote_id not in self.banned_senders:
self.banned_senders.add(context.remote_id)
self.tensor_part_reducer.on_sender_failed(sender_index)
Member

Please fix this code duplication (L305-308, L314-317, L352-355); it is likely to lead to a bug if someone decides to change the ban procedure but changes only 1-2 of the 3 copies.

Member Author

replaced with a unified await self._ban_sender(peer_id) call
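
For context, the unified helper presumably looks something like this (a sketch; the exact signature and the sender_peer_ids lookup are assumptions, not verbatim from the PR):

    async def _ban_sender(self, peer_id: PeerID) -> None:
        async with self.banlock:
            if peer_id not in self.banned_senders:
                self.banned_senders.add(peer_id)
                # map the peer back to its sender index before notifying the reducer
                self.tensor_part_reducer.on_sender_failed(self.sender_peer_ids.index(peer_id))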

@borzunov borzunov changed the title Allreduce Fault Tolerance Improve All-Reduce fault-tolerance Dec 13, 2021
@justheuristic justheuristic merged commit 6da8683 into master Dec 14, 2021
@justheuristic justheuristic deleted the fault_tolerant_allreduce branch December 14, 2021 09:05