Fix minor bugs in GradientAverager #410

Merged · 2 commits · Nov 16, 2021

Changes from 1 commit
24 changes: 13 additions & 11 deletions hivemind/optim/experimental/grad_averager.py
@@ -6,7 +6,7 @@
 import hivemind
 from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging.control import StepControl
-from hivemind.utils import DHTExpiration, get_logger
+from hivemind.utils import DHTExpiration, get_dht_time, get_logger

 logger = get_logger(__name__)

@@ -80,7 +80,7 @@ def __init__(
         if reuse_grad_buffers and accumulate_grads_on is not None:
             logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
         client_mode = client_mode if client_mode is not None else dht.client_mode
-        self._parameters = tuple(parameters)
+        self.parameters = tuple(parameters)
         self.reuse_grad_buffers = reuse_grad_buffers
         self.warn = warn
         self.local_samples_accumulated = 0
@@ -102,7 +102,7 @@ def __init__(

     def _grads_from_parameters(self) -> Iterator[torch.Tensor]:
         """gradient buffers associated with parameters"""
-        for param in self._parameters:
+        for param in self.parameters:
             if param.grad is None:
                 param.grad = torch.zeros_like(param)
             yield param.grad
@@ -152,6 +152,7 @@ def step(
         weight: Optional[float] = None,
         reset_accumulators: bool = True,
         control: Optional[StepControl] = None,
+        timeout: Optional[float] = None,
         wait: bool = True,
         **kwargs,
     ):
@@ -161,12 +162,13 @@
         :param weight: overrides the averaging weight; by default, weight equals the number of accumulated samples
         :param reset_accumulators: by default, set local gradient accumulators to zeros after averaging succeeds
         :param control: reuse a pre-arranged group of peers (or a matchmaking in progress) from averager.schedule_step
+        :param timeout: if specified, await for averaging round for at most this number of seconds (if wait=True)
         :param wait: if True, await for the step to finish (or fail), otherwise run all-reduce in background
         """
         if control is None:
-            control = self.schedule_step(**kwargs)
+            control = self.schedule_step(timeout=timeout, **kwargs)
         elif len(kwargs) > 0:
-            RuntimeError(f"Averaging with a pre-scheduled group, parameters {kwargs} will have no effect.")
+            raise RuntimeError(f"Averaging with a pre-scheduled group, parameters {kwargs} will have no effect.")
         assert not control.triggered, f"This {type(control)} instance was already used."
         self._load_accumulators_into_averager_()
         self._accumulators_used_in_step = True
@@ -175,9 +177,9 @@
         control.weight = self.local_samples_accumulated if weight is None else weight
         if reset_accumulators:
             self.reset_accumulated_grads_()
-
         control.allow_allreduce()
-        return control.result() if wait else control
+
+        return control.result(timeout) if wait else control

     @torch.no_grad()
     def _load_accumulators_into_averager_(self):
@@ -209,11 +211,11 @@ def use_averaged_gradients(self):
         self._new_averaged_grads = False
         with self.get_tensors() as averaged_grads:
             try:
-                assert len(averaged_grads) == len(self._parameters)
-                old_grads = [param.grad for param in self._parameters]
-                for param, new_grad in zip(self._parameters, averaged_grads):
+                assert len(averaged_grads) == len(self.parameters)
+                old_grads = [param.grad for param in self.parameters]
+                for param, new_grad in zip(self.parameters, averaged_grads):
                     param.grad = new_grad
                 yield
             finally:
-                for param, old_grad in zip(self._parameters, old_grads):
+                for param, old_grad in zip(self.parameters, old_grads):
                     param.grad = old_grad
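
For context, a minimal usage sketch of the changed behavior. Only `step(timeout=...)` and `use_averaged_gradients()` correspond to lines touched in this diff; the DHT setup, model, `demo_run` prefix, batch size, and the surrounding `accumulate_grads_` training-loop call are assumptions for illustration, not part of this PR.

```python
# Hypothetical usage sketch; names and hyperparameters below are placeholders.
import torch
import hivemind
from hivemind.optim.experimental.grad_averager import GradientAverager

dht = hivemind.DHT(start=True)
model = torch.nn.Linear(16, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

averager = GradientAverager(model.parameters(), dht=dht, prefix="demo_run", start=True)

loss = model(torch.randn(4, 16)).mean()
loss.backward()
averager.accumulate_grads_(batch_size=4)

# After this PR, timeout is forwarded to schedule_step() and control.result(),
# and passing extra kwargs together with a pre-scheduled `control` actually raises.
# (Assumes other peers with the same prefix are running; otherwise no group is found.)
averager.step(timeout=30.0)

with averager.use_averaged_gradients():  # temporarily swaps param.grad with averaged tensors
    opt.step()
```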