Support PyTorch 2.0.0 #559

Merged: 6 commits, Mar 28, 2023

examples/albert/run_trainer.py (1 addition, 2 deletions)

@@ -18,6 +18,7 @@
 from transformers.trainer_utils import is_main_process

 from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
+from hivemind.optim.state_averager import LRSchedulerBase
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.networking import log_visible_maddrs

@@ -33,8 +34,6 @@
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)

-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
-

 def setup_transformers_logging(process_rank: int):
     if is_main_process(process_rank):

hivemind/moe/server/module_backend.py (5 additions, 1 deletion)

@@ -8,9 +8,13 @@
 from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack
 from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor

-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 logger = get_logger(__name__)

+try:
+    LRSchedulerBase = torch.optim.lr_scheduler.LRScheduler
+except AttributeError:  # torch < 2.0.0
+    LRSchedulerBase = torch.optim.lr_scheduler._LRScheduler
+

 class ModuleBackend:
     """

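Background on this shim (not part of the diff): PyTorch 2.0 promoted the scheduler base class to the public name torch.optim.lr_scheduler.LRScheduler, while older releases only expose the private _LRScheduler. The try/except above resolves whichever name exists so downstream isinstance checks and type hints keep working. A minimal standalone sketch of the same pattern, assuming only a working torch install; the StepLR usage is illustrative and not taken from the PR:

import torch

# Resolve the scheduler base class across torch versions:
# torch >= 2.0 exposes LRScheduler; older releases only have _LRScheduler.
try:
    LRSchedulerBase = torch.optim.lr_scheduler.LRScheduler
except AttributeError:  # torch < 2.0.0
    LRSchedulerBase = torch.optim.lr_scheduler._LRScheduler

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=10)

# Any concrete scheduler passes an isinstance check against the resolved base class.
assert isinstance(sched, LRSchedulerBase)
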
hivemind/optim/optimizer.py (8 additions, 6 deletions)

@@ -15,6 +15,7 @@
 from hivemind.optim.grad_scaler import GradScaler
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
 from hivemind.optim.state_averager import (
+    ZERO_GRAD_SET_TO_NONE_DEFAULT,
     LRSchedulerBase,
     OptimizerFactory,
     Parameters,

@@ -621,7 +622,10 @@ def _load_averaged_gradients_into_optimizer_(self):
             with torch.no_grad(), self.grad_averager.get_tensors() as averaged_gradients:
                 assert len(averaged_gradients) == len(optimized_parameters)
                 for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
-                    opt_param.grad.copy_(averaged_grad, non_blocking=True)
+                    if opt_param.grad is None:
+                        opt_param.grad = averaged_grad.clone()
+                    else:
+                        opt_param.grad.copy_(averaged_grad, non_blocking=True)

         self.grad_averager.notify_used_averaged_gradients()

@@ -634,7 +638,7 @@ def _load_local_gradients_into_optimizer(self):
         # - if not offload_optimizer, we must un-scale gradients (divide them by the number of accumulation steps)
         self._load_averaged_gradients_into_optimizer_()

-    def zero_grad(self, set_to_none: bool = False):
+    def zero_grad(self, set_to_none: bool = ZERO_GRAD_SET_TO_NONE_DEFAULT):
         """Reset gradients from model. If reuse_grad_buffers=True, this will raise an error."""
         if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
             raise ValueError(

@@ -643,11 +647,9 @@
             )
         for param_group in self.param_groups:
             for param in param_group["params"]:
-                if param.grad is None:
-                    pass
-                elif set_to_none:
+                if set_to_none:
                     param.grad = None
-                else:
+                elif param.grad is not None:
                     param.grad.zero_()

     def _should_load_state_from_peers(self) -> bool:

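The recurring pattern above copes with gradient buffers that may have been released: after zero_grad(set_to_none=True), param.grad is None and can no longer be written in place. A hedged, standalone illustration of that guard; copy_grad_ below is a hypothetical helper invented for this note, not a hivemind function:

import torch

def copy_grad_(param: torch.nn.Parameter, source: torch.Tensor) -> None:
    # If the .grad buffer was dropped (set_to_none=True), allocate it from a copy;
    # otherwise reuse the existing buffer with an in-place copy.
    if param.grad is None:
        param.grad = source.clone()
    else:
        param.grad.copy_(source, non_blocking=True)

p = torch.nn.Parameter(torch.zeros(3))
copy_grad_(p, torch.ones(3))            # .grad is None -> allocated via clone()
copy_grad_(p, torch.full((3,), 2.0))    # .grad exists -> overwritten in place
assert torch.equal(p.grad, torch.full((3,), 2.0))
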
hivemind/optim/state_averager.py (19 additions, 3 deletions)

@@ -8,6 +8,7 @@
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union

 import torch
+from packaging.version import Version

 import hivemind
 from hivemind.averaging import DecentralizedAverager

@@ -22,7 +23,12 @@
 Parameters = Iterable[torch.Tensor]
 ParamGroups = Iterable[Dict[str, Any]]
 TorchOptimizer = torch.optim.Optimizer
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
+if Version(torch.__version__).major >= 2:
+    ZERO_GRAD_SET_TO_NONE_DEFAULT = True
+    LRSchedulerBase = torch.optim.lr_scheduler.LRScheduler
+else:
+    ZERO_GRAD_SET_TO_NONE_DEFAULT = False
+    LRSchedulerBase = torch.optim.lr_scheduler._LRScheduler
 OptimizerFactory = Callable[[Union[Parameters, ParamGroups]], TorchOptimizer]
 SchedulerFactory = Callable[[TorchOptimizer], LRSchedulerBase]

@@ -332,6 +338,7 @@ def step(
         averaging_control: Optional[StepControl] = None,
         wait_for_trigger: Optional[Callable[[], Any]] = None,
         grad_scaler: Optional[GradScaler] = None,
+        set_to_none: bool = ZERO_GRAD_SET_TO_NONE_DEFAULT,
         averaging_opts: Optional[Dict[str, Any]] = None,
     ):
         """

@@ -353,6 +360,8 @@
         :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
         :note: if wait_for_trigger fails with any exception, it will abort optimizer step, zero grad and averaging
         :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
+        :param set_to_none: if True, zero_grad sets local gradients to None instead of zero tensors
+          (default in PyTorch 2.0+)
         :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         """
         if delay_averaging is None:

@@ -430,6 +439,7 @@ def step(
                 averaging_round,
                 averaging_control,
                 grad_scaler,
+                set_to_none,
                 **averaging_opts or {},
             )
             self.pending_updates.add(pending_update)

@@ -472,6 +482,7 @@ def _do(
         averaging_round: bool,
         averaging_control: Optional[StepControl],
         grad_scaler: Optional[GradScaler],
+        set_to_none: bool,
         timeout: Optional[float] = None,
         **kwargs,
     ):

@@ -515,7 +526,9 @@
             self.optimizer.zero_grad()
             if self.offload_optimizer:
                 for parameter in self.main_parameters:
-                    if parameter.grad is not None:
+                    if set_to_none:
+                        parameter.grad = None
+                    elif parameter.grad is not None:
                         parameter.grad.zero_()

         self._update_scheduler()

@@ -566,7 +579,10 @@ def _load_local_grads_into_optimizer_(self):
         opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
         for main_param, opt_param in zip(self.main_parameters, opt_parameters):
             if main_param.grad is not None:
-                opt_param.grad.copy_(main_param.grad, non_blocking=True)
+                if opt_param.grad is None:
+                    opt_param.grad = main_param.grad.clone()
+                else:
+                    opt_param.grad.copy_(main_param.grad, non_blocking=True)

     @torch.no_grad()
     def _apply_optimizer_parameters_(self):

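Context for ZERO_GRAD_SET_TO_NONE_DEFAULT (not part of the diff): PyTorch 2.0 changed the default of Optimizer.zero_grad(set_to_none=...) from False to True, so after a reset the .grad attributes are None rather than zero-filled tensors. The flag above mirrors whichever default the installed torch uses, which is why the in-place copy_ call sites gain a None branch. A minimal sketch of the behavioral difference, assuming a plain torch install:

import torch

model = torch.nn.Linear(3, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.randn(2, 3)).sum().backward()
opt.zero_grad(set_to_none=False)     # pre-2.0 default: buffers kept, filled with zeros
assert torch.all(model.weight.grad == 0)

model(torch.randn(2, 3)).sum().backward()
opt.zero_grad(set_to_none=True)      # 2.0+ default: buffers released entirely
assert model.weight.grad is None
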
requirements.txt (1 addition, 1 deletion)

@@ -1,5 +1,5 @@
 PyYAML
-torch>=1.9.0,<2.0.0
+torch>=1.9.0
 numpy>=1.17
 scipy>=1.2.1
 prefetch_generator>=1.0.1

Review comment from @borzunov (Member), Mar 28, 2023, on the torch requirement:

    We'll have torch 1.13.0 on Python 3.7, torch 2.0.0 on Python 3.8-3.10. This allows us to test both torch 1.13.0 and 2.0.0.

    We can drop Python 3.7 and PyTorch < 2.0 support when necessary.

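A note on the version spread the reviewer describes: the PR itself only relaxes the pin to torch>=1.9.0; the 1.13-vs-2.0 split falls out of which torch wheels exist for each Python version in CI. If one ever wanted to encode that split explicitly, PEP 508 environment markers could express it. The two lines below are purely illustrative and are not part of the actual requirements.txt:

torch>=1.13.0,<2.0.0; python_version < "3.8"
torch>=2.0.0; python_version >= "3.8"
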
tests/test_optimizer.py (10 additions, 4 deletions)

@@ -15,7 +15,7 @@
 from hivemind.optim.optimizer import Optimizer
 from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import ProgressTracker
-from hivemind.optim.state_averager import TrainingStateAverager
+from hivemind.optim.state_averager import ZERO_GRAD_SET_TO_NONE_DEFAULT, TrainingStateAverager
 from hivemind.utils.crypto import RSAPrivateKey

@@ -79,8 +79,11 @@ def test_grad_averager(grad_averager_factory: GradientAveragerFactory):
     assert torch.allclose(model2.w.grad, ref_average)

     # after no longer use_averaged_gradients
-    assert not torch.allclose(model1.w.grad, ref_average)
-    assert not torch.allclose(model2.w.grad, ref_average)
+    if ZERO_GRAD_SET_TO_NONE_DEFAULT:  # averager1 has reuse_grad_buffers=False
+        assert model1.w.grad is None
+    else:
+        assert not torch.allclose(model1.w.grad, ref_average)
+    assert not torch.allclose(model2.w.grad, ref_average)  # averager2 has reuse_grad_buffers=True


 @pytest.mark.forked

@@ -151,7 +154,10 @@ def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch
         F.mse_loss(model2(x), -torch.ones(3)).backward()
         avgr2.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=False)

-    assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), "zero grad did not trigger"
+    if ZERO_GRAD_SET_TO_NONE_DEFAULT:
+        assert model1.weight.grad is None and model2.weight.grad is None, ".zero_grad() wasn't called"
+    else:
+        assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), ".zero_grad() wasn't called"
     assert model1(x).mean() > 0.5 and model2(x).mean() < -0.5, "models did not train properly"
     assert torch.allclose(extras1[0], extras2[0]), "first extra tensors were not averaged"
     assert torch.allclose(extras1[1], extras2[1]), "second extra tensors were not averaged"