Support PyTorch 2.0.0 (#559)
- Fix LRSchedulerBase
- Handle None after .zero_grad() in torch 2.0.0
- Use set_to_none=True by default in torch>=2.0
- Add set_to_none param to TrainingStateAverager.step()

Co-authored-by: Aleksandr Borzunov <[email protected]>
(cherry picked from commit 98531ce)
justheuristic authored and mryab committed Mar 31, 2023
1 parent ac2e85c commit 6a21a73
Showing 6 changed files with 44 additions and 17 deletions.
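Background for the bullets above (an editorial sketch, not part of the commit): starting with torch 2.0, Optimizer.zero_grad() defaults to set_to_none=True, so .grad may be None where older versions left a zeroed tensor.

    # Minimal illustration of the torch >= 2.0 behavior that the diffs below account for.
    import torch

    param = torch.nn.Parameter(torch.ones(3))
    param.sum().backward()
    opt = torch.optim.SGD([param], lr=0.1)

    opt.zero_grad()    # torch >= 2.0 defaults to set_to_none=True; torch 1.x zeroed the tensor in place
    print(param.grad)  # None on torch >= 2.0, tensor([0., 0., 0.]) on torch 1.x

Code that reads .grad after zero_grad() therefore has to tolerate None, which is what the changes below do.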
examples/albert/run_trainer.py (3 changes: 1 addition & 2 deletions)
@@ -18,6 +18,7 @@
 from transformers.trainer_utils import is_main_process
 
 from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
+from hivemind.optim.state_averager import LRSchedulerBase
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.networking import log_visible_maddrs

@@ -33,8 +34,6 @@
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
-
 
 def setup_transformers_logging(process_rank: int):
     if is_main_process(process_rank):
hivemind/moe/server/module_backend.py (6 changes: 5 additions & 1 deletion)
@@ -8,9 +8,13 @@
 from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack
 from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
 
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 logger = get_logger(__name__)
 
+try:
+    LRSchedulerBase = torch.optim.lr_scheduler.LRScheduler
+except AttributeError:  # torch < 2.0.0
+    LRSchedulerBase = torch.optim.lr_scheduler._LRScheduler
+
 
 class ModuleBackend:
     """
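Torch 2.0 exposes the scheduler base class under the public name LRScheduler (older versions only have the private _LRScheduler), hence the try/except fallback above. A hedged sketch of how the resolved LRSchedulerBase can be used for type checks (illustrative, not taken from this diff):

    # Resolve the base class on either torch version, then isinstance-check a real scheduler.
    import torch

    try:
        LRSchedulerBase = torch.optim.lr_scheduler.LRScheduler
    except AttributeError:  # torch < 2.0.0
        LRSchedulerBase = torch.optim.lr_scheduler._LRScheduler

    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=1)
    assert isinstance(scheduler, LRSchedulerBase)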
hivemind/optim/optimizer.py (14 changes: 8 additions & 6 deletions)
@@ -15,6 +15,7 @@
 from hivemind.optim.grad_scaler import GradScaler
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
 from hivemind.optim.state_averager import (
+    ZERO_GRAD_SET_TO_NONE_DEFAULT,
     LRSchedulerBase,
     OptimizerFactory,
     Parameters,
@@ -621,7 +622,10 @@ def _load_averaged_gradients_into_optimizer_(self):
             with torch.no_grad(), self.grad_averager.get_tensors() as averaged_gradients:
                 assert len(averaged_gradients) == len(optimized_parameters)
                 for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
-                    opt_param.grad.copy_(averaged_grad, non_blocking=True)
+                    if opt_param.grad is None:
+                        opt_param.grad = averaged_grad.clone()
+                    else:
+                        opt_param.grad.copy_(averaged_grad, non_blocking=True)
 
         self.grad_averager.notify_used_averaged_gradients()
 
@@ -634,7 +638,7 @@ def _load_local_gradients_into_optimizer(self):
         # - if not offload_optimizer, we must un-scale gradients (divide them by the number of accumulation steps)
         self._load_averaged_gradients_into_optimizer_()
 
-    def zero_grad(self, set_to_none: bool = False):
+    def zero_grad(self, set_to_none: bool = ZERO_GRAD_SET_TO_NONE_DEFAULT):
         """Reset gradients from model. If reuse_grad_buffers=True, this will raise an error."""
         if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
             raise ValueError(
@@ -643,11 +647,9 @@ def zero_grad(self, set_to_none: bool = ZERO_GRAD_SET_TO_NONE_DEFAULT):
             )
         for param_group in self.param_groups:
             for param in param_group["params"]:
-                if param.grad is None:
-                    pass
-                elif set_to_none:
+                if set_to_none:
                     param.grad = None
-                else:
+                elif param.grad is not None:
                     param.grad.zero_()
 
     def _should_load_state_from_peers(self) -> bool:
hivemind/optim/state_averager.py (22 changes: 19 additions & 3 deletions)
@@ -8,6 +8,7 @@
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 
 import torch
+from packaging.version import Version
 
 import hivemind
 from hivemind.averaging import DecentralizedAverager
@@ -22,7 +23,12 @@
 Parameters = Iterable[torch.Tensor]
 ParamGroups = Iterable[Dict[str, Any]]
 TorchOptimizer = torch.optim.Optimizer
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
+if Version(torch.__version__).major >= 2:
+    ZERO_GRAD_SET_TO_NONE_DEFAULT = True
+    LRSchedulerBase = torch.optim.lr_scheduler.LRScheduler
+else:
+    ZERO_GRAD_SET_TO_NONE_DEFAULT = False
+    LRSchedulerBase = torch.optim.lr_scheduler._LRScheduler
 OptimizerFactory = Callable[[Union[Parameters, ParamGroups]], TorchOptimizer]
 SchedulerFactory = Callable[[TorchOptimizer], LRSchedulerBase]
 
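As a quick reference for the version guard above (illustrative, not part of the commit), packaging.version.Version parses torch's version string, including local build suffixes, and exposes the major component:

    from packaging.version import Version

    assert Version("1.13.1").major == 1
    assert Version("2.0.0+cu117").major == 2  # torch.__version__ often carries a local +cuXXX suffix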
@@ -332,6 +338,7 @@ def step(
         averaging_control: Optional[StepControl] = None,
         wait_for_trigger: Optional[Callable[[], Any]] = None,
         grad_scaler: Optional[GradScaler] = None,
+        set_to_none: bool = ZERO_GRAD_SET_TO_NONE_DEFAULT,
         averaging_opts: Optional[Dict[str, Any]] = None,
     ):
         """
@@ -353,6 +360,8 @@
         :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
         :note: if wait_for_trigger fails with any exception, it will abort optimizer step, zero grad and averaging
         :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
+        :param set_to_none: if True, zero_grad sets local gradients to None instead of zero tensors
+          (default in PyTorch 2.0+)
         :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         """
         if delay_averaging is None:
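A hedged usage sketch of the parameter documented above; state_averager is assumed to be an already-constructed TrainingStateAverager (its setup is not shown in this diff):

    # Default follows ZERO_GRAD_SET_TO_NONE_DEFAULT (True on torch >= 2.0, False otherwise).
    state_averager.step(optimizer_step=True, zero_grad=True)
    # Explicitly keep zeroed gradient tensors instead of None, regardless of torch version.
    state_averager.step(optimizer_step=True, zero_grad=True, set_to_none=False)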
@@ -430,6 +439,7 @@
                 averaging_round,
                 averaging_control,
                 grad_scaler,
+                set_to_none,
                 **averaging_opts or {},
             )
             self.pending_updates.add(pending_update)
@@ -472,6 +482,7 @@ def _do(
         averaging_round: bool,
         averaging_control: Optional[StepControl],
         grad_scaler: Optional[GradScaler],
+        set_to_none: bool,
         timeout: Optional[float] = None,
         **kwargs,
     ):
@@ -515,7 +526,9 @@
                     self.optimizer.zero_grad()
                     if self.offload_optimizer:
                         for parameter in self.main_parameters:
-                            if parameter.grad is not None:
+                            if set_to_none:
+                                parameter.grad = None
+                            elif parameter.grad is not None:
                                 parameter.grad.zero_()
 
                 self._update_scheduler()
@@ -566,7 +579,10 @@ def _load_local_grads_into_optimizer_(self):
         opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
         for main_param, opt_param in zip(self.main_parameters, opt_parameters):
             if main_param.grad is not None:
-                opt_param.grad.copy_(main_param.grad, non_blocking=True)
+                if opt_param.grad is None:
+                    opt_param.grad = main_param.grad.clone()
+                else:
+                    opt_param.grad.copy_(main_param.grad, non_blocking=True)
 
     @torch.no_grad()
     def _apply_optimizer_parameters_(self):
requirements.txt (2 changes: 1 addition & 1 deletion)
@@ -1,5 +1,5 @@
 PyYAML
-torch>=1.9.0,<2.0.0
+torch>=1.9.0
 numpy>=1.17
 scipy>=1.2.1
 prefetch_generator>=1.0.1
tests/test_optimizer.py (14 changes: 10 additions & 4 deletions)
@@ -15,7 +15,7 @@
 from hivemind.optim.optimizer import Optimizer
 from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import ProgressTracker
-from hivemind.optim.state_averager import TrainingStateAverager
+from hivemind.optim.state_averager import ZERO_GRAD_SET_TO_NONE_DEFAULT, TrainingStateAverager
 from hivemind.utils.crypto import RSAPrivateKey
 
 
@@ -79,8 +79,11 @@ def test_grad_averager(grad_averager_factory: GradientAveragerFactory):
         assert torch.allclose(model2.w.grad, ref_average)
 
     # after no longer use_averaged_gradients
-    assert not torch.allclose(model1.w.grad, ref_average)
-    assert not torch.allclose(model2.w.grad, ref_average)
+    if ZERO_GRAD_SET_TO_NONE_DEFAULT:  # averager1 has reuse_grad_buffers=False
+        assert model1.w.grad is None
+    else:
+        assert not torch.allclose(model1.w.grad, ref_average)
+    assert not torch.allclose(model2.w.grad, ref_average)  # averager2 has reuse_grad_buffers=True
 
 
 @pytest.mark.forked
@@ -151,7 +154,10 @@ def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch
         F.mse_loss(model2(x), -torch.ones(3)).backward()
         avgr2.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=False)
 
-    assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), "zero grad did not trigger"
+    if ZERO_GRAD_SET_TO_NONE_DEFAULT:
+        assert model1.weight.grad is None and model2.weight.grad is None, ".zero_grad() wasn't called"
+    else:
+        assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), ".zero_grad() wasn't called"
     assert model1(x).mean() > 0.5 and model2(x).mean() < -0.5, "models did not train properly"
     assert torch.allclose(extras1[0], extras2[0]), "first extra tensors were not averaged"
     assert torch.allclose(extras1[1], extras2[1]), "second extra tensors were not averaged"