Support auxiliary peers in CollaborativeOptimizer (#279)
* Support auxiliary peers for our ALBERT training

* Update hivemind/optim/collaborative.py

Co-authored-by: justheuristic <[email protected]>

* Remove unnecessary

* Update hivemind/optim/collaborative.py

Co-authored-by: Max Ryabinin <[email protected]>

* Fixes for review

* range(len()) -> enumerate

* Update config.yml

Co-authored-by: justheuristic <[email protected]>
Co-authored-by: Max Ryabinin <[email protected]>
3 people authored Jun 22, 2021
1 parent b6fbae4 commit 86f3c0d
Showing 3 changed files with 41 additions and 3 deletions.
1 change: 1 addition & 0 deletions .circleci/config.yml
@@ -68,6 +68,7 @@ jobs:
       - image: circleci/python:3.9.1
     steps:
       - checkout
+      - run: ulimit -n 4096 # temporary workaround for py39
       - restore_cache:
           keys:
             - py39-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
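For context, `ulimit -n 4096` raises the shell's soft limit on open file descriptors before the tests run. A rough Python equivalent, shown only as an illustrative sketch and not part of this commit:

import resource

# Raise the soft limit on open file descriptors to 4096 (without exceeding the hard limit),
# mirroring the effect of the `ulimit -n 4096` step for the py39 CI job.
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
new_soft = 4096 if hard == resource.RLIM_INFINITY else min(4096, hard)
resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))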
5 changes: 4 additions & 1 deletion examples/albert/run_first_peer.py
@@ -17,7 +17,6 @@
 from hivemind.utils.logging import get_logger
 import metrics_utils
 
-
 logger = get_logger(__name__)
 
 
@@ -163,6 +162,10 @@ def upload_checkpoint(self, current_loss):
                        for peer in metrics_dict]
             latest_step = max(item.step for item in metrics)
             if latest_step != current_step:
+                logger.debug(f"Got metrics from {len(metrics)} peers")
+
+                for i, metrics_for_peer in enumerate(metrics):
+                    logger.debug(f"{i} peer {metrics_for_peer}")
                 current_step = latest_step
                 alive_peers = 0
                 num_batches = 0
38 changes: 36 additions & 2 deletions hivemind/optim/collaborative.py
@@ -232,9 +232,43 @@ def step(self, batch_size: Optional[int] = None, **kwargs):
             self.collaboration_state_updated.set()
             self.update_scheduler()
 
-            logger.log(self.status_loglevel, f"Optimizer step: done!")
+        logger.log(self.status_loglevel, f"Optimizer step: done!")
 
-            return group_info
+        return group_info
 
+    def step_aux(self, **kwargs):
+        """
+        Find and assist other peers in averaging without sending local gradients.
+        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
+        """
+
+        if not self.collaboration_state.ready_for_step:
+            return
+
+        logger.log(self.status_loglevel,
+                   f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
+        self.collaboration_state = self.fetch_collaboration_state()
+        self.collaboration_state_updated.set()
+
+        with self.lock_collaboration_state:
+            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
+            current_step, group_info = self.averager.local_step, None
+            try:
+                group_info = self.averager.step(timeout=self.averaging_timeout, **kwargs)
+                if group_info:
+                    logger.log(self.status_loglevel,
+                               f"Averaged tensors successfully with {len(group_info)} peers")
+            except BaseException as e:
+                logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
+
+            self.collaboration_state.register_step(current_step + 1)
+            self.averager.local_step = current_step + 1
+            self.collaboration_state_updated.set()
+
+        logger.log(self.status_loglevel, f"Optimizer step: done!")
+
+        return group_info
+
     def _grad_buffers(self) -> Iterator[torch.Tensor]:
         """ pytorch-internal gradient buffers """
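The new `step_aux` entry point lets an auxiliary peer, one that computes no gradients of its own, keep joining averaging rounds and advancing the shared optimizer step. A minimal usage sketch, assuming an already-constructed `CollaborativeOptimizer` instance named `collab_opt`; the driver loop and one-second polling interval are illustrative, not part of this commit:

import time

# collab_opt: a CollaborativeOptimizer constructed elsewhere (assumed), running
# on a peer that holds no local gradients.
while True:
    # step_aux() assists other peers with averaging without sending local gradients;
    # it returns early whenever the collaboration is not yet ready for a step.
    collab_opt.step_aux()
    time.sleep(1)  # hypothetical polling interval, not prescribed by the commit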
