Support different AMP & buffer configurations in one experiment, fix minor bugs #389

Merged · 20 commits · Sep 28, 2021
1 change: 1 addition & 0 deletions hivemind/optim/__init__.py
@@ -1,4 +1,5 @@
from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
from hivemind.optim.base import DecentralizedOptimizerBase
from hivemind.optim.collaborative import CollaborativeOptimizer
from hivemind.optim.grad_scaler import HivemindGradScaler
from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD
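With this one-line addition, HivemindGradScaler is re-exported at the package level. A minimal sketch of the resulting import surface (illustrative only, not code from the PR):

from hivemind.optim import CollaborativeOptimizer, HivemindGradScaler
# equivalent to the fully qualified form used inside the library:
from hivemind.optim.grad_scaler import HivemindGradScaler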
69 changes: 47 additions & 22 deletions hivemind/optim/collaborative.py
@@ -14,8 +14,9 @@
from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
from hivemind.optim.base import DecentralizedOptimizerBase
from hivemind.optim.grad_scaler import HivemindGradScaler
from hivemind.optim.performance_ema import PerformanceEMA
from hivemind.utils import Endpoint, get_dht_time, get_logger
from hivemind.utils import get_dht_time, get_logger

logger = get_logger(__name__)
LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
@@ -147,6 +148,8 @@ def __init__(
self.status_loglevel = logging.INFO if verbose else logging.DEBUG
self.averager = self._make_averager(**kwargs)

self._step_supports_amp_scaling = self.reuse_grad_buffers # enable custom execution with torch GradScaler

self.training_progress_key = f"{self.prefix}_progress"
self.local_samples_accumulated = 0 # a number of local samples accumulated since last optimizer update
self.local_updates_accumulated = 0 # a number of calls to step() since last optimizer update
@@ -197,6 +200,8 @@ def load_state_from_peers(self, **kwargs):
try:
self.averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
break
except KeyboardInterrupt:
raise
except BaseException as e:
logger.exception(f"Failed to load state from peers: {e}, retrying ...")
continue
@@ -205,13 +210,16 @@ def load_state_from_peers(self, **kwargs):
self.reset_accumulated_grads_()
self.update_scheduler()

def step(self, batch_size: Optional[int] = None, **kwargs):
def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindGradScaler] = None, **kwargs):
"""
Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters

:param batch_size: optional override for batch_size_per_step from init
:param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
:note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
"""
if grad_scaler is not None and not isinstance(grad_scaler, HivemindGradScaler):
raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler).")
if self.batch_size_per_step is None:
if batch_size is None:
raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
@@ -227,6 +235,13 @@ def step(self, batch_size: Optional[int] = None, **kwargs):
self.averager.local_step = self.collaboration_state.optimizer_step
logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.")

if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
self.local_samples_accumulated = self.local_updates_accumulated = 0
self.reset_accumulated_grads_()
self.should_report_progress.set()
return

if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
logger.warning(
f"Training step took {get_dht_time() - self.last_step_time}, "
@@ -251,6 +266,10 @@ def step(self, batch_size: Optional[int] = None, **kwargs):

# divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
self.apply_accumulated_grads_(scale_by=1.0 / self.local_updates_accumulated)
if grad_scaler is not None:
with grad_scaler.running_global_step():
assert grad_scaler.unscale_(self)

current_step, group_info = self.averager.local_step, None

if self.collaboration_state.num_peers > 1:
@@ -279,13 +298,21 @@ def step(self, batch_size: Optional[int] = None, **kwargs):
f"Skipped averaging: collaboration consists of " f"{self.collaboration_state.num_peers} peer(s).",
)

self.opt.step()
if grad_scaler is not None:
with grad_scaler.running_global_step():
assert grad_scaler.step(self)
else:
self.opt.step()

self.reset_accumulated_grads_()
self.local_samples_accumulated = self.local_updates_accumulated = 0
self.collaboration_state.register_step(current_step + 1)
self.averager.local_step = current_step + 1
self.collaboration_state_updated.set()
self.update_scheduler()
if grad_scaler is not None:
with grad_scaler.running_global_step():
assert grad_scaler.update()

logger.log(self.status_loglevel, f"Optimizer step: done!")

@@ -344,38 +371,36 @@ def accumulated_grads(self) -> Iterator[torch.Tensor]:
"""local gradient accumulators"""
if self.reuse_grad_buffers:
yield from self._grad_buffers()
elif self._grads is None:
with torch.no_grad():
self._grads = [
torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()
]
return

if self._grads is None:
self._grads = [torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()]
yield from self._grads
Review comment (Member Author): This would actually be an error with reuse_grad_buffers=True, but it worked because no one asked for more than len(grad_buffers) elements.
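A standalone sketch of the failure mode described in this comment (a hypothetical toy generator, not code from the PR): the old accumulated_grads() fell through to the final yield from even when reuse_grad_buffers=True, but callers such as accumulate_grads_() consume it through zip() and never advance it past len(grad_buffers) items, so the latent TypeError never fired.

def old_style_accumulators(grad_buffers, grads=None):
    # mirrors the pre-fix control flow: no early return after yielding the reused buffers
    yield from grad_buffers
    yield from grads  # grads is still None here -> TypeError, but only if this line is reached

buffers = [1.0, 2.0, 3.0]
list(zip(buffers, old_style_accumulators(buffers)))  # works: zip stops after len(buffers) items
# list(old_style_accumulators(buffers))              # would raise TypeError: 'NoneType' object is not iterable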


@torch.no_grad()
def accumulate_grads_(self, batch_size: int):
"""add current gradients to grad accumulators (if any)"""
if self.reuse_grad_buffers:
return # user is responsible for accumulating gradients in .grad buffers
alpha = float(batch_size) / self.batch_size_per_step
for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
# user is responsible for accumulating gradients in .grad buffers
assert batch_size == self.batch_size_per_step, "Custom batch size is not supported if reuse_grad_buffers"
else:
alpha = float(batch_size) / self.batch_size_per_step
for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)

@torch.no_grad()
def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
if self.reuse_grad_buffers:
Review comment (justheuristic, Sep 25, 2021): This previously caused a bug where reuse=True peers would not be scaled by scale_by. As a result, if there was a mix of reuse=True and reuse=False peers, the reuse=True peers would have larger gradients and dominate the reuse=False peers.
return
for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
grad_buf[...] = grad_acc.to(grad_buf.device)
if scale_by is not None:
if not self.reuse_grad_buffers:
for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
grad_buf.copy_(grad_acc.to(grad_buf.device), non_blocking=True)
if scale_by is not None:
for grad_buf in self._grad_buffers():
grad_buf.mul_(scale_by)
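To make the note above concrete: step() calls apply_accumulated_grads_(scale_by=1.0 / local_updates_accumulated), and after this change both buffer modes apply that factor. A toy illustration with made-up numbers, assuming each call used a full batch_size_per_step (not code from the PR):

import torch

# after N = 2 calls to step(), both .grad (reuse=True) and the accumulator (reuse=False)
# hold a sum of per-step gradients
summed_grad = torch.tensor([2.0]) + torch.tensor([4.0])  # tensor([6.])
scale_by = 1.0 / 2                                        # 1.0 / local_updates_accumulated

# before the fix, reuse_grad_buffers=True returned early and skipped this multiplication,
# so such peers contributed tensor([6.]) to averaging while reuse=False peers contributed tensor([3.])
true_average_grad = summed_grad * scale_by                # tensor([3.])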

@torch.no_grad()
def reset_accumulated_grads_(self):
if self.reuse_grad_buffers:
self.opt.zero_grad()
else:
for grad_buf in self.accumulated_grads():
grad_buf.zero_()
for grad_buf in self.accumulated_grads():
grad_buf.zero_()

def report_training_progress(self):
"""Periodically publish metadata and the current number of samples accumulated towards the next step"""
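For context on how the extended step() signature in this file is meant to be driven, a rough usage sketch follows (model, batch, batch_size, compute_loss, and opt — an already constructed CollaborativeOptimizer with reuse_grad_buffers=False — are placeholders, not part of the PR):

import torch
from hivemind.optim import HivemindGradScaler

scaler = HivemindGradScaler()  # a plain torch.cuda.amp.GradScaler would be rejected by step()

with torch.cuda.amp.autocast():
    loss = compute_loss(model, batch)  # placeholder forward pass
scaler.scale(loss).backward()

# CollaborativeOptimizer decides internally whether this call merely accumulates gradients
# or performs a global averaging step; the scaler defers unscale_()/update() until the latter
opt.step(batch_size=batch_size, grad_scaler=scaler)
model.zero_grad()  # with reuse_grad_buffers=True, leave .grad in place between calls instead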
83 changes: 83 additions & 0 deletions hivemind/optim/grad_scaler.py
@@ -0,0 +1,83 @@
import contextlib
from typing import Dict, Optional

import torch
from torch.cuda.amp import GradScaler as TorchGradScaler
from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
from torch.optim import Optimizer

from hivemind.optim.base import DecentralizedOptimizerBase
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)


class HivemindGradScaler(TorchGradScaler):
    """
    A thin wrapper over pytorch GradScaler that supports hivemind-style training with CollaborativeOptimizer, namely:
    - bypass .unscale_ and .update calls in order to accumulate gradients over several steps
    - limit increasing gradient scale to only immediately after global optimizer steps
    - allow training with some or all master parameters in fp16
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._is_running_global_step = False
        self._optimizer_states_to_reset = set()

    @contextlib.contextmanager
    def running_global_step(self):
        was_running, self._is_running_global_step = self._is_running_global_step, True
        try:
            yield
        finally:
            self._is_running_global_step = was_running

    def unscale_(self, optimizer: Optimizer) -> bool:
        assert isinstance(optimizer, DecentralizedOptimizerBase)
        if self._is_running_global_step:
            super().unscale_(optimizer.opt)
            return True
        else:
            self._check_inf_per_device(optimizer.opt)
            self._optimizer_states_to_reset.add(id(optimizer))
            return False

    def step(self, optimizer: Optimizer, *args, **kwargs) -> bool:
        assert isinstance(optimizer, DecentralizedOptimizerBase)
        if self._is_running_global_step:
            if self.are_grads_finite(optimizer):
                super().step(optimizer.opt, *args, **kwargs)
            else:
                logger.warning("Skipping global step due to gradient over/underflow")
            return True
        else:
            super().step(optimizer)
            self._optimizer_states_to_reset.add(id(optimizer))
            return False

    def update(self, new_scale: Optional[float] = None) -> bool:
        total_infs = 0
        for optimizer_state in self._per_optimizer_states.values():
            total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())

        if self._is_running_global_step or total_infs != 0:
            # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
            super().update(new_scale)
            return True
        else:
            for opt_id in self._optimizer_states_to_reset:
                self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
            self._optimizer_states_to_reset.clear()
            return False

    def _unscale_grads_(
        self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
    ) -> Dict[torch.device, torch.Tensor]:
        # note: the code below sets allow_fp16=True to allow training with master weights (partially) in fp16
        # inspired by: https://github.com/facebookresearch/fairscale/blob/945b9666/fairscale/optim/grad_scaler.py
        return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)
Review comment (Member): Why does it ignore the allow_fp16 value, always setting it to True?

Reply (Member Author): That's the same trick that fairscale uses to allow training without master fp32 weights: https://github.com/facebookresearch/fairscale/blob/main/fairscale/optim/grad_scaler.py (got referred there by @TimDettmers). Added a quick comment explaining that.


    def are_grads_finite(self, optimizer: DecentralizedOptimizerBase) -> bool:
        assert isinstance(optimizer, DecentralizedOptimizerBase)
        return not sum(v.item() for v in self._check_inf_per_device(optimizer.opt).values())
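To ground the allow_fp16 exchange above: a stock torch.cuda.amp.GradScaler refuses to unscale gradients that belong to fp16 master parameters, which is exactly the configuration this override keeps working. A minimal reproduction of the stock behaviour (requires a CUDA device; illustrative only, not part of the PR):

import torch

param = torch.nn.Parameter(torch.zeros(4, device="cuda", dtype=torch.float16))
opt = torch.optim.SGD([param], lr=0.1)
scaler = torch.cuda.amp.GradScaler()  # note: NOT the HivemindGradScaler defined above

loss = (param * 2).sum()
scaler.scale(loss).backward()
try:
    scaler.unscale_(opt)  # stock _unscale_grads_ runs with allow_fp16=False
except ValueError as e:
    print(e)  # "Attempting to unscale FP16 gradients."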