Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Statistics averaging #229

Merged
merged 6 commits into from
Apr 18, 2021
Merged

Statistics averaging #229

merged 6 commits into from
Apr 18, 2021

Conversation

nevec
Copy link
Collaborator

@nevec nevec commented Apr 16, 2021

This PR extends TrainingAverager with optimizer's stats averaging. It should serve as base for subsequent decentralized adaptive optims implementations.

  • implement averaging feature
  • add basic test
  • move hivemind.client.optim into separate module
  • fix typo, which has lead to infinite recursion

@nevec nevec requested a review from justheuristic April 16, 2021 21:51
@@ -46,7 +49,7 @@ def __init__(self, opt: torch.optim.Optimizer, *, average_parameters: bool, aver
def step(self, wait: bool = True, **kwargs):
""" Average optimizer weights and gradients with peers. """
if not wait:
return run_in_background(self.step, wait=False, **kwargs)
return run_in_background(self.step, wait=True, **kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch!

@justheuristic justheuristic requested review from mryab and borzunov and removed request for mryab April 17, 2021 12:17
grad_avg = 0.5 * (x1.grad + x2.grad)
stats_avg = 0.5 * (opt1.state[x1]["exp_avg_sq"] + opt2.state[x2]["exp_avg_sq"])

f1 = averager1.step(wait=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f1 = averager1.step(wait=False)
# We set wait=False to test hivemind.utils.run_in_background() usage
f1 = averager1.step(wait=False)

nit: Using wait=False and then waiting for the result looked surprising to me. I'd suggest clarifying this with the comment.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually the main purpose of wait=False is to prevent deadlock, when averager1 waits for averager2 to join.
Fix: write this explicitly in comment

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, sure, missed that! Thanks for the explanation :)

Copy link
Member

@borzunov borzunov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I approve the PR, but ask you to consider the minor change I've suggested :)

@nevec nevec merged commit 8c3bd93 into master Apr 18, 2021
@nevec nevec deleted the extend_averager branch April 18, 2021 13:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants