Fix initialization of optimizers in DDP Strategy (#11952)
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
3 people authored and lexierule committed Jun 1, 2022
1 parent f89b181 · commit a5f82f5
Showing 11 changed files with 151 additions and 48 deletions.
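
The core of the change, per the CHANGELOG entry below: `DDPStrategy` and `DDPSpawnStrategy` now create their optimizers only after the module has been moved to its device, and the related strategies (Bagua, sharded, fully sharded, TPU spawn) are updated to keep their own setup order consistent. A minimal, strategy-free sketch of that ordering; the `nn.Linear` model and SGD optimizer below are illustrative stand-ins for a LightningModule and its `configure_optimizers`, not part of this commit:

```python
import torch
from torch import nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 1. move the module to the device first ...
model = nn.Linear(32, 2).to(device)

# 2. ... then build the optimizer, so it is constructed over the on-device parameters
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
```

Previously the DDP strategies delegated to the base `Strategy.setup`, which created the optimizers before the device move; the overridden `setup` methods below inline that sequence in the corrected order.
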
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -30,6 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed torchelastic detection with non-distributed installations ([#13142](https://github.com/PyTorchLightning/pytorch-lightning/pull/13142))
- Fixed logging's step values when multiple dataloaders are used during evaluation ([#12184](https://github.com/PyTorchLightning/pytorch-lightning/pull/12184))
- Fixed epoch logging on train epoch end ([#13025](https://github.com/PyTorchLightning/pytorch-lightning/pull/13025))
- Fixed `DDPStrategy` and `DDPSpawnStrategy` to initialize optimizers only after moving the module to the device ([#11952](https://github.com/PyTorchLightning/pytorch-lightning/pull/11952))


## [1.6.3] - 2022-05-03
34 changes: 31 additions & 3 deletions pytorch_lightning/strategies/bagua.py
@@ -16,9 +16,11 @@
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _BAGUA_AVAILABLE
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.seed import reset_seed

if _BAGUA_AVAILABLE:
@@ -152,6 +154,33 @@ def _set_node_environment_variables(self) -> None:
os.environ["WORLD_SIZE"] = str(self.world_size)
os.environ["LOCAL_RANK"] = str(self.local_rank)

def setup(self, trainer: "pl.Trainer") -> None:
self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
if self._should_run_deadlock_detection():
self._share_information_to_prevent_deadlock()

self.accelerator.setup(trainer)

# move the model to the correct device
self.model_to_device()

trainer_fn = trainer.state.fn

if trainer_fn == TrainerFn.FITTING:
if self._layer_sync and self.model:
self.model = self._layer_sync.apply(self.model)

self.setup_precision_plugin()

if trainer_fn == TrainerFn.FITTING:
# set up optimizers after the module has been moved to the device
# but before the module has been wrapped
self.setup_optimizers(trainer)
optimizers_to_device(self.optimizers, self.root_device)

# skip wrapping the model if we are not fitting as no gradients need to be exchanged
self._configure_bagua_model(trainer)

def _check_qadam_optimizer(self) -> None:
has_qadam_optimizer = any([isinstance(opt, QAdamOptimizer) for opt in self.optimizers])

@@ -160,13 +189,12 @@ def _check_qadam_optimizer(self) -> None:

self._bagua_kwargs["q_adam_optimizer"] = self.optimizers[0]

def configure_ddp(self) -> None:
def _configure_bagua_model(self, trainer: "pl.Trainer") -> None:
model = LightningBaguaModule(self.model) # type: ignore[arg-type]
self._model = self._setup_model(model)

# start the background communication for async algorithm
assert self.lightning_module.trainer is not None
if self.lightning_module.trainer.training and self._bagua_algorithm == "async":
if trainer.training and self._bagua_algorithm == "async":
self.model.bagua_algorithm.resume(self.model) # type: ignore

def _setup_model(self, model: Module) -> BaguaDistributedDataParallel:
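
In the Bagua strategy, `configure_ddp` becomes the private `_configure_bagua_model(trainer)`, which receives the trainer explicitly instead of reaching through `self.lightning_module.trainer`, and the optimizers are now created after `model_to_device()` but before the module is wrapped. A hedged usage sketch; `LitModel` below is a made-up placeholder module, not part of this commit:

```python
import torch
import pytorch_lightning as pl
from pytorch_lightning.strategies import BaguaStrategy


class LitModel(pl.LightningModule):
    """Placeholder LightningModule, used only to show where the optimizer comes from."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        # With this commit, this hook runs after the module is on its GPU and
        # before BaguaDistributedDataParallel wraps it.
        return torch.optim.SGD(self.parameters(), lr=0.1)


# Requires bagua and GPUs; shown only to illustrate where the strategy plugs in.
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy=BaguaStrategy(algorithm="gradient_allreduce"))
# trainer.fit(LitModel(), train_dataloaders=...)  # dataloader omitted for brevity
```
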
32 changes: 20 additions & 12 deletions pytorch_lightning/strategies/ddp.py
@@ -56,6 +56,7 @@
_TORCH_GREATER_EQUAL_1_10,
_TORCH_GREATER_EQUAL_1_11,
)
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -152,24 +153,37 @@ def setup_environment(self) -> None:
super().setup_environment()

def setup(self, trainer: "pl.Trainer") -> None:
super().setup(trainer)
# share ddp pids to all processes
self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
if self._should_run_deadlock_detection():
self._share_information_to_prevent_deadlock()

self.accelerator.setup(trainer)

# move the model to the correct device
self.model_to_device()

# skip wrapping the model if we are not fitting as no gradients need to be exchanged
trainer_fn = trainer.state.fn
if trainer_fn != TrainerFn.FITTING:
return

if self._layer_sync:
self.model = self._layer_sync.apply(self.model)
if trainer_fn == TrainerFn.FITTING:
if self._layer_sync:
self.model = self._layer_sync.apply(self.model)

self.setup_precision_plugin()

if trainer_fn == TrainerFn.FITTING:
self.configure_ddp()

self.configure_ddp()
# set up optimizers after the wrapped module has been moved to the device
self.setup_optimizers(trainer)
optimizers_to_device(self.optimizers, self.root_device)

if _TORCH_GREATER_EQUAL_1_10 and trainer_fn == TrainerFn.FITTING:
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD

if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
self._enable_model_averaging()

def _setup_model(self, model: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
@@ -223,12 +237,6 @@ def _register_ddp_hooks(self) -> None:
ddp_comm_wrapper=self._ddp_comm_wrapper,
)

if _TORCH_GREATER_EQUAL_1_10 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD

if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
self._enable_model_averaging()

def _enable_model_averaging(self) -> None:
# Only called when PyTorch version >= 1.10
log.detail(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD")
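
`optimizers_to_device`, imported above, is what keeps any already-populated optimizer state (for example, state restored from a checkpoint) on the strategy's root device once the optimizers exist. A rough sketch of the idea, not the library's exact implementation:

```python
import torch
from torch.optim import Optimizer


def move_optimizer_state(optimizer: Optimizer, device: torch.device) -> None:
    # Walk the per-parameter state (momentum buffers, Adam's exp_avg, ...) and
    # move every tensor onto the target device. The parameters themselves are
    # already there because `model_to_device()` ran first.
    for state in optimizer.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.to(device)
```
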
23 changes: 15 additions & 8 deletions pytorch_lightning/strategies/ddp_spawn.py
@@ -42,6 +42,7 @@
sync_ddp_if_available,
)
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -122,20 +123,22 @@ def _configure_launcher(self):

def setup(self, trainer: "pl.Trainer") -> None:
os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
super().setup(trainer)

self.accelerator.setup(trainer)

# move the model to the correct device
self.model_to_device()

trainer_fn = self.lightning_module.trainer.state.fn
if trainer_fn != TrainerFn.FITTING:
return
# skip wrapping the model if we are not fitting as no gradients need to be exchanged
trainer_fn = trainer.state.fn
if trainer_fn == TrainerFn.FITTING:
if self._layer_sync:
self.model = self._layer_sync.apply(self.model)

if self._layer_sync:
self.model = self._layer_sync.apply(self.model)
self.setup_precision_plugin()

# skip wrapping the model if we are not fitting as no gradients need to be exchanged
self.configure_ddp()
if trainer_fn == TrainerFn.FITTING:
self.configure_ddp()

def _setup_model(self, model: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
@@ -186,6 +189,10 @@ def configure_ddp(self) -> None:
self.model = self._setup_model(LightningDistributedModule(self.model))
self._register_ddp_hooks()

# set up optimizers after the wrapped module has been moved to the device
self.setup_optimizers(self.lightning_module.trainer)
optimizers_to_device(self.optimizers, self.root_device)

def determine_ddp_device_ids(self):
if self.root_device.type == "cpu":
return None
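
For `ddp_spawn`, `setup` runs inside every spawned worker, so each process moves its own replica to its device before optimizers exist; optimizer creation now lives in `configure_ddp`, after the DDP wrap. A stripped-down sketch of that per-worker ordering, with process-group setup and the DDP wrap elided:

```python
import torch
import torch.multiprocessing as mp
from torch import nn


def worker(rank: int, world_size: int) -> None:
    # Each spawned process moves its replica to its own device *before*
    # constructing the optimizer, mirroring the order `ddp_spawn` now uses.
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    model = nn.Linear(8, 2).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # ... init_process_group, wrap with DistributedDataParallel, train ...


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```
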
15 changes: 10 additions & 5 deletions pytorch_lightning/strategies/fully_sharded.py
@@ -139,14 +139,16 @@ def setup_distributed(self) -> None:
def setup(self, trainer: "pl.Trainer") -> None:
self.accelerator.setup(trainer)

if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
self.model = self._layer_sync.apply(self.model)
if trainer.state.fn == TrainerFn.FITTING:
self.setup_optimizers(trainer)
optimizers_to_device(self.optimizers, self.root_device)

if self._layer_sync:
self.model = self._layer_sync.apply(self.model)

self.setup_precision_plugin()
self.configure_ddp()
self.barrier()
self.setup_optimizers(trainer)
optimizers_to_device(self.optimizers, self.root_device)
self.setup_precision_plugin()

@contextlib.contextmanager
def model_sharded_context(self) -> Generator:
@@ -183,6 +185,9 @@ def configure_ddp(self) -> None:
# (TODO: need to figure out solution)
self.model_to_device()

# setup optimizers after fully sharded has wrapped the lightning module
self.setup_optimizers(self.lightning_module.trainer)

def model_to_device(self) -> None:
log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
# ensure we update the device type in the lightning module
50 changes: 38 additions & 12 deletions pytorch_lightning/strategies/sharded.py
@@ -25,6 +25,7 @@
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_only

if _FAIRSCALE_AVAILABLE:
@@ -40,16 +41,41 @@ class DDPShardedStrategy(DDPStrategy):
strategy_name = "ddp_sharded"
_REDUCE_BUFFER_SIZE_DEFAULT: int = 2**23 # 8M

def configure_ddp(self) -> None:
trainer = self.lightning_module.trainer
if "reduce_buffer_size" not in self._ddp_kwargs:
# For multi-node training, enabling bucketing will improve performance.
self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
def setup(self, trainer: "pl.Trainer") -> None:
# share ddp pids to all processes
self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
if self._should_run_deadlock_detection():
self._share_information_to_prevent_deadlock()

self.accelerator.setup(trainer)

# move the model to the correct device
self.model_to_device()

# skip wrapping the model if we are not fitting as no gradients need to be exchanged
trainer_fn = trainer.state.fn
if trainer_fn == TrainerFn.FITTING:
if self._layer_sync:
self.model = self._layer_sync.apply(self.model)

self.setup_precision_plugin()

if trainer_fn == TrainerFn.FITTING:
self.configure_ddp()

def configure_ddp(self) -> None:
self._set_ddp_kwargs()
self.setup_optimizers(self.model.trainer)
self.model, self.optimizers = self._setup_model_and_optimizers(
model=LightningShardedDataParallel(self.model),
optimizers=trainer.optimizers,
optimizers=self.optimizers,
)
optimizers_to_device(self.optimizers, self.root_device)

def _set_ddp_kwargs(self) -> None:
if "reduce_buffer_size" not in self._ddp_kwargs:
# For multi-node training, enabling bucketing will improve performance.
self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0

def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
"""Wraps the model and optimizers with fairscale components.
@@ -62,6 +88,12 @@ def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]
model = ShardedDataParallel(model, sharded_optimizer=optimizers, **self._ddp_kwargs)
return model, optimizers

def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
return optimizers

return self._reinit_optimizers_with_oss(optimizers)

def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
for x, optimizer in enumerate(optimizers):
if isinstance(optimizer, LightningOptimizer):
@@ -79,12 +111,6 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, Lightnin
del optimizer
return optimizers

def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
return optimizers

return self._reinit_optimizers_with_oss(optimizers)

def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
if isinstance(optimizer, LightningOptimizer):
optimizer = optimizer._optimizer
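
`_reinit_optimizers_with_oss` rebuilds each user optimizer as a FairScale `OSS` (ZeRO-style sharded) optimizer, which is why `configure_ddp` now creates the plain optimizers immediately before `_setup_model_and_optimizers` and moves their state to the root device right after. A hedged sketch of that re-initialization, assuming FairScale's `OSS(params, optim=..., **defaults)` constructor:

```python
import torch
from fairscale.optim import OSS  # requires fairscale to be installed


def reinit_with_oss(optimizer: torch.optim.Optimizer) -> OSS:
    # Rebuild the optimizer as a sharded OSS optimizer, preserving its parameter
    # groups and default hyperparameters (a sketch, not the library code).
    optim_class = type(optimizer)
    return OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
```
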
4 changes: 4 additions & 0 deletions pytorch_lightning/strategies/sharded_spawn.py
@@ -23,6 +23,7 @@
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_only

if _FAIRSCALE_AVAILABLE:
@@ -38,9 +39,12 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
strategy_name = "ddp_sharded_spawn"

def configure_ddp(self) -> None:
# set up optimizers after the wrapped module has been moved to the device
self.setup_optimizers(self.lightning_module.trainer)
self.model, self.optimizers = self._setup_model_and_optimizers(
model=LightningShardedDataParallel(self.model), optimizers=self.optimizers
)
optimizers_to_device(self.optimizers, self.root_device)

def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
"""Wraps the model and optimizers with fairscale components.
11 changes: 6 additions & 5 deletions pytorch_lightning/strategies/tpu_spawn.py
@@ -26,6 +26,7 @@
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from pytorch_lightning.strategies.launchers.xla_spawn import _XLASpawnLauncher
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.distributed import ReduceOp
@@ -126,9 +127,6 @@ def _configure_launcher(self):
def setup(self, trainer: "pl.Trainer") -> None:
self.start_method = "fork"
self.accelerator.setup(trainer)
self.setup_optimizers(trainer)
self.setup_precision_plugin()
optimizers_to_device(self.optimizers, self.root_device)

if self.debug:
os.environ["PT_XLA_DEBUG"] = str(1)
@@ -140,8 +138,11 @@ def setup(self, trainer: "pl.Trainer") -> None:
else:
set_shared_parameters(self.model.module, shared_params)

self.setup_optimizers(trainer)
self.precision_plugin.connect(self.model, None, None)
self.setup_precision_plugin()

if trainer.state.fn == TrainerFn.FITTING:
self.setup_optimizers(trainer)
optimizers_to_device(self.optimizers, self.root_device)

def _setup_model(self, model: Module) -> Module:
return model
6 changes: 3 additions & 3 deletions tests/strategies/test_bagua_strategy.py
@@ -85,9 +85,9 @@ def test_configuration(algorithm, tmpdir):
), mock.patch("bagua.torch_api.communication.is_initialized", return_value=True):
if algorithm == "qadam":
with pytest.raises(MisconfigurationException, match="Bagua QAdam can only accept one QAdamOptimizer"):
trainer.strategy.configure_ddp()
trainer.strategy._configure_bagua_model(trainer)
else:
trainer.strategy.configure_ddp()
trainer.strategy._configure_bagua_model(trainer)


@RunIf(min_gpus=1, bagua=True)
@@ -109,7 +109,7 @@ def test_qadam_configuration(tmpdir):
with mock.patch(
"bagua.torch_api.data_parallel.bagua_distributed.BaguaDistributedDataParallel.__init__", return_value=None
), mock.patch("bagua.torch_api.communication.is_initialized", return_value=True):
trainer.strategy.configure_ddp()
trainer.strategy._configure_bagua_model(trainer)


def test_bagua_not_available(monkeypatch):