diff --git a/examples/albert/run_trainer.py b/examples/albert/run_trainer.py
index 7fa550a92..9e9445cf8 100755
--- a/examples/albert/run_trainer.py
+++ b/examples/albert/run_trainer.py
@@ -18,6 +18,7 @@
 from transformers.trainer_utils import is_main_process
 
 from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
+from hivemind.optim.state_averager import LRSchedulerBase
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.networking import log_visible_maddrs
 
@@ -33,8 +34,6 @@
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
-
 
 def setup_transformers_logging(process_rank: int):
     if is_main_process(process_rank):
diff --git a/hivemind/moe/server/module_backend.py b/hivemind/moe/server/module_backend.py
index 5688b8ecc..199cc9f28 100644
--- a/hivemind/moe/server/module_backend.py
+++ b/hivemind/moe/server/module_backend.py
@@ -8,9 +8,13 @@
 from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack
 from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
 
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 logger = get_logger(__name__)
 
+try:
+    LRSchedulerBase = torch.optim.lr_scheduler.LRScheduler
+except AttributeError:  # torch < 2.0.0
+    LRSchedulerBase = torch.optim.lr_scheduler._LRScheduler
+
 
 class ModuleBackend:
     """
diff --git a/hivemind/optim/optimizer.py b/hivemind/optim/optimizer.py
index 97c9436ab..ef0a05cc9 100644
--- a/hivemind/optim/optimizer.py
+++ b/hivemind/optim/optimizer.py
@@ -15,6 +15,7 @@
 from hivemind.optim.grad_scaler import GradScaler
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
 from hivemind.optim.state_averager import (
+    ZERO_GRAD_SET_TO_NONE_DEFAULT,
     LRSchedulerBase,
     OptimizerFactory,
     Parameters,
@@ -621,7 +622,10 @@ def _load_averaged_gradients_into_optimizer_(self):
         with torch.no_grad(), self.grad_averager.get_tensors() as averaged_gradients:
             assert len(averaged_gradients) == len(optimized_parameters)
             for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
-                opt_param.grad.copy_(averaged_grad, non_blocking=True)
+                if opt_param.grad is None:
+                    opt_param.grad = averaged_grad.clone()
+                else:
+                    opt_param.grad.copy_(averaged_grad, non_blocking=True)
 
         self.grad_averager.notify_used_averaged_gradients()
 
@@ -634,7 +638,7 @@ def _load_local_gradients_into_optimizer(self):
         # - if not offload_optimizer, we must un-scale gradients (divide them by the number of accumulation steps)
         self._load_averaged_gradients_into_optimizer_()
 
-    def zero_grad(self, set_to_none: bool = False):
+    def zero_grad(self, set_to_none: bool = ZERO_GRAD_SET_TO_NONE_DEFAULT):
         """Reset gradients from model. If reuse_grad_buffers=True, this will raise an error."""
         if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
             raise ValueError(
@@ -643,11 +647,9 @@ def zero_grad(self, set_to_none: bool = False):
             )
         for param_group in self.param_groups:
             for param in param_group["params"]:
-                if param.grad is None:
-                    pass
-                elif set_to_none:
+                if set_to_none:
                     param.grad = None
-                else:
+                elif param.grad is not None:
                     param.grad.zero_()
 
     def _should_load_state_from_peers(self) -> bool:
diff --git a/hivemind/optim/state_averager.py b/hivemind/optim/state_averager.py
index 794260fc4..f7a94f7b3 100644
--- a/hivemind/optim/state_averager.py
+++ b/hivemind/optim/state_averager.py
@@ -8,6 +8,7 @@
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 
 import torch
+from packaging.version import Version
 
 import hivemind
 from hivemind.averaging import DecentralizedAverager
@@ -22,7 +23,12 @@
 Parameters = Iterable[torch.Tensor]
 ParamGroups = Iterable[Dict[str, Any]]
 TorchOptimizer = torch.optim.Optimizer
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
+if Version(torch.__version__).major >= 2:
+    ZERO_GRAD_SET_TO_NONE_DEFAULT = True
+    LRSchedulerBase = torch.optim.lr_scheduler.LRScheduler
+else:
+    ZERO_GRAD_SET_TO_NONE_DEFAULT = False
+    LRSchedulerBase = torch.optim.lr_scheduler._LRScheduler
 OptimizerFactory = Callable[[Union[Parameters, ParamGroups]], TorchOptimizer]
 SchedulerFactory = Callable[[TorchOptimizer], LRSchedulerBase]
 
@@ -332,6 +338,7 @@ def step(
         averaging_control: Optional[StepControl] = None,
         wait_for_trigger: Optional[Callable[[], Any]] = None,
         grad_scaler: Optional[GradScaler] = None,
+        set_to_none: bool = ZERO_GRAD_SET_TO_NONE_DEFAULT,
         averaging_opts: Optional[Dict[str, Any]] = None,
     ):
         """
@@ -353,6 +360,8 @@ def step(
         :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
         :note: if wait_for_trigger fails with any exception, it will abort optimizer step, zero grad and averaging
         :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
+        :param set_to_none: if True, zero_grad sets local gradients to None instead of zero tensors
+          (default in PyTorch 2.0+)
         :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         """
         if delay_averaging is None:
@@ -430,6 +439,7 @@ def step(
                 averaging_round,
                 averaging_control,
                 grad_scaler,
+                set_to_none,
                 **averaging_opts or {},
             )
             self.pending_updates.add(pending_update)
@@ -472,6 +482,7 @@ def _do(
         averaging_round: bool,
         averaging_control: Optional[StepControl],
         grad_scaler: Optional[GradScaler],
+        set_to_none: bool,
         timeout: Optional[float] = None,
         **kwargs,
     ):
@@ -515,7 +526,9 @@ def _do(
                     self.optimizer.zero_grad()
                     if self.offload_optimizer:
                         for parameter in self.main_parameters:
-                            if parameter.grad is not None:
+                            if set_to_none:
+                                parameter.grad = None
+                            elif parameter.grad is not None:
                                 parameter.grad.zero_()
 
                 self._update_scheduler()
@@ -566,7 +579,10 @@ def _load_local_grads_into_optimizer_(self):
         opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
         for main_param, opt_param in zip(self.main_parameters, opt_parameters):
             if main_param.grad is not None:
-                opt_param.grad.copy_(main_param.grad, non_blocking=True)
+                if opt_param.grad is None:
+                    opt_param.grad = main_param.grad.clone()
+                else:
+                    opt_param.grad.copy_(main_param.grad, non_blocking=True)
 
     @torch.no_grad()
     def _apply_optimizer_parameters_(self):
diff --git a/requirements.txt b/requirements.txt
index f7bc952d9..9c483ac16 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 PyYAML
-torch>=1.9.0,<2.0.0
+torch>=1.9.0
 numpy>=1.17
 scipy>=1.2.1
 prefetch_generator>=1.0.1
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index e2361ae03..c859e3879 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -15,7 +15,7 @@
 from hivemind.optim.optimizer import Optimizer
 from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import ProgressTracker
-from hivemind.optim.state_averager import TrainingStateAverager
+from hivemind.optim.state_averager import ZERO_GRAD_SET_TO_NONE_DEFAULT, TrainingStateAverager
 from hivemind.utils.crypto import RSAPrivateKey
 
 
@@ -79,8 +79,11 @@ def test_grad_averager(grad_averager_factory: GradientAveragerFactory):
         assert torch.allclose(model2.w.grad, ref_average)
 
     # after no longer use_averaged_gradients
-    assert not torch.allclose(model1.w.grad, ref_average)
-    assert not torch.allclose(model2.w.grad, ref_average)
+    if ZERO_GRAD_SET_TO_NONE_DEFAULT:  # averager1 has reuse_grad_buffers=False
+        assert model1.w.grad is None
+    else:
+        assert not torch.allclose(model1.w.grad, ref_average)
+    assert not torch.allclose(model2.w.grad, ref_average)  # averager2 has reuse_grad_buffers=True
 
 
 @pytest.mark.forked
@@ -151,7 +154,10 @@ def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch
         F.mse_loss(model2(x), -torch.ones(3)).backward()
         avgr2.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=False)
 
-        assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), "zero grad did not trigger"
+        if ZERO_GRAD_SET_TO_NONE_DEFAULT:
+            assert model1.weight.grad is None and model2.weight.grad is None, ".zero_grad() wasn't called"
+        else:
+            assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), ".zero_grad() wasn't called"
         assert model1(x).mean() > 0.5 and model2(x).mean() < -0.5, "models did not train properly"
         assert torch.allclose(extras1[0], extras2[0]), "first extra tensors were not averaged"
         assert torch.allclose(extras1[1], extras2[1]), "second extra tensors were not averaged"
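
Review note: the heart of this change is the version gate in hivemind/optim/state_averager.py, which resolves both the LR-scheduler base class and the default value of set_to_none from the installed torch. The sketch below is standalone and not part of the diff (the toy Linear model and SGD optimizer are illustrative only); it mirrors that gate and demonstrates the behavioral difference the rest of the diff has to accommodate: with the torch 2.0 default, zero_grad leaves .grad as None instead of a zeroed tensor.

```python
# Minimal sketch of the version gate from the diff; names mirror state_averager.py,
# the model/optimizer below are toy stand-ins for illustration.
import torch
from packaging.version import Version

if Version(torch.__version__).major >= 2:
    ZERO_GRAD_SET_TO_NONE_DEFAULT = True           # torch>=2.0 defaults zero_grad to set_to_none=True
    LRSchedulerBase = torch.optim.lr_scheduler.LRScheduler
else:
    ZERO_GRAD_SET_TO_NONE_DEFAULT = False          # torch<2.0 keeps zeroed gradient tensors
    LRSchedulerBase = torch.optim.lr_scheduler._LRScheduler

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Any built-in scheduler is an instance of the resolved base class on either torch version.
assert isinstance(torch.optim.lr_scheduler.StepLR(opt, step_size=1), LRSchedulerBase)

model(torch.randn(2, 4)).sum().backward()
opt.zero_grad(set_to_none=ZERO_GRAD_SET_TO_NONE_DEFAULT)

if ZERO_GRAD_SET_TO_NONE_DEFAULT:
    assert model.weight.grad is None               # gradient buffers are dropped entirely
else:
    assert torch.all(model.weight.grad == 0)       # gradient buffers are kept, filled with zeros
```

This difference is also why the assertions in tests/test_optimizer.py now branch on ZERO_GRAD_SET_TO_NONE_DEFAULT rather than always expecting zero-filled gradients.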
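
A second note on the clone-or-copy branches added in _load_averaged_gradients_into_optimizer_ and _load_local_grads_into_optimizer_: once gradients may be reset to None, the old unconditional opt_param.grad.copy_(...) can dereference a missing buffer. A minimal sketch of the pattern, with toy tensors standing in for the real parameter and averaged-gradient pair:

```python
# Standalone illustration of the clone-or-copy pattern from the diff; the tensors
# below are toy stand-ins, not hivemind objects.
import torch

opt_param = torch.nn.Parameter(torch.zeros(3))     # optimizer-side parameter, .grad is None here
averaged_grad = torch.ones(3)                      # stand-in for an averaged gradient buffer

# The pre-patch code would fail at this point:
#     opt_param.grad.copy_(averaged_grad)  ->  AttributeError: 'NoneType' object has no attribute 'copy_'
if opt_param.grad is None:
    opt_param.grad = averaged_grad.clone()         # materialize a fresh gradient buffer
else:
    opt_param.grad.copy_(averaged_grad, non_blocking=True)

assert torch.equal(opt_param.grad, averaged_grad)
```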