diff --git a/CHANGELOG.md b/CHANGELOG.md
index b9c16ffddccf9..0de3835f6cf7e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -70,6 +70,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))
 
+- Added `configure_sharded_model` hook ([#6679](https://github.com/PyTorchLightning/pytorch-lightning/pull/6679))
+
+
 - Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595))
 
 
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index f4c068f298e89..7d16d91e3bf82 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union
+import contextlib
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union
 
 import torch
 from torch.optim import Optimizer
@@ -439,6 +440,18 @@ def results(self) -> Any:
         """
         return self.training_type_plugin.results
 
+    @contextlib.contextmanager
+    def model_sharded_context(self) -> Generator:
+        """
+        Provide a hook to create modules in a distributed aware context. This is useful when we'd like to
+        shard the model instantly, which can save memory and
+        initialization time for extremely large models.
+
+        Returns: Model parallel context.
+        """
+        with self.training_type_plugin.model_sharded_context():
+            yield
+
     # todo: remove in v1.5
     def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
         """
@@ -468,4 +481,33 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
         self.setup_precision_plugin(plugin)
 
     def save_checkpoint(self, checkpoint: Dict[str, Any], filepath) -> None:
+        """Save model/training states as a checkpoint file through state-dump and file-write.
+
+        Args:
+            checkpoint: dict containing model and trainer state
+            filepath: write-target file's path
+        """
         self.training_type_plugin.save_checkpoint(checkpoint, filepath)
+
+    @property
+    def call_configure_sharded_model_hook(self) -> bool:
+        """
+        Allow the model parallel hook to be called in suitable environments determined by the training type plugin.
+        This is useful when we want to shard the model once within fit.
+        Returns: True if we want to call the model parallel setup hook.
+        """
+        return self.training_type_plugin.call_configure_sharded_model_hook
+
+    @call_configure_sharded_model_hook.setter
+    def call_configure_sharded_model_hook(self, mode: bool) -> None:
+        self.training_type_plugin.call_configure_sharded_model_hook = mode
+
+    @property
+    def setup_optimizers_in_pre_dispatch(self) -> bool:
+        """
+        Override to delay setting optimizers and schedulers until after dispatch.
+        This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model.
+        However, this may break certain precision plugins such as APEX, which require optimizers to be set.
+        Returns: If True, delay setup of optimizers until pre_dispatch, else call within setup.
+ """ + return self.training_type_plugin.setup_optimizers_in_pre_dispatch diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 7757902bd3baf..768e4ebca30ee 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -29,6 +29,9 @@ class Callback(abc.ABC): Subclass this class and override any of the relevant hooks """ + def on_configure_sharded_model(self, trainer, pl_module: LightningModule) -> None: + """Called before configure sharded model""" + def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule) -> None: """Called before accelerator is being setup""" pass diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index 58324e363cd37..a7485814b1b17 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -42,6 +42,7 @@ def __init__( self, on_before_accelerator_backend_setup: Optional[Callable] = None, setup: Optional[Callable] = None, + on_configure_sharded_model: Optional[Callable] = None, teardown: Optional[Callable] = None, on_init_start: Optional[Callable] = None, on_init_end: Optional[Callable] = None, @@ -83,6 +84,8 @@ def __init__( self.on_before_accelerator_backend_setup = on_before_accelerator_backend_setup if setup is not None: self.setup = setup + if on_configure_sharded_model is not None: + self.on_configure_sharded_model = on_configure_sharded_model if teardown is not None: self.teardown = teardown if on_init_start is not None: diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index bf3b0bf605679..b320a9b223840 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -310,6 +310,20 @@ def on_post_move_to_device(self): """ + def configure_sharded_model(self) -> None: + """ + Hook to create modules in a distributed aware context. This is useful for when using sharded plugins, + where we'd like to shard the model instantly, which is useful for extremely large models + which can save memory and initialization time. + + The accelerator manages whether to call this hook at every given stage. + For sharded plugins where model parallelism is required, the hook is usually on called once + to initialize the sharded parameters, and not called again in the same process. + + By default for accelerators/plugins that do not use model sharding techniques, + this hook is called during each fit/val/test/predict stages. + """ + class DataHooks: """Hooks to be used for data related stuff.""" diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index aee2b8914b579..ba074e7cfb206 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -300,9 +300,8 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. 
 
         Args:
-            trainer: PyTorch Lightning Trainer
+            checkpoint: dict containing model and trainer state
             filepath: write-target file's path
-            weights_only: saving model weights only
         """
         # Todo: TypeError: 'mappingproxy' object does not support item assignment
         self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index ca097c32513c6..1eac88212e0fb 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import contextlib
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Iterable, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, Dict, Generator, Iterable, Optional, TYPE_CHECKING, Union
 
 import torch
 from torch.nn import Module
@@ -35,6 +36,7 @@ class TrainingTypePlugin(Plugin, ABC):
     def __init__(self) -> None:
         self._model = None
         self._results = None
+        self._call_configure_sharded_model_hook = True
 
     def connect(self, model: 'Module') -> None:
         """Called by the accelerator to connect the accelerator and the model with this plugin"""
@@ -196,6 +198,12 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
         return False
 
     def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+        """Save model/training states as a checkpoint file through state-dump and file-write.
+
+        Args:
+            checkpoint: dict containing model and trainer state
+            filepath: write-target file's path
+        """
         # dump states as a checkpoint dictionary object
         if self.is_global_zero:
             checkpoint = self.on_save(checkpoint)
@@ -210,3 +218,27 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
                 f' An attribute is not picklable {err}'
             )
             atomic_save(checkpoint, filepath)
+
+    @contextlib.contextmanager
+    def model_sharded_context(self) -> Generator:
+        """
+        Provide a hook to create modules in a distributed aware context. This is useful when we'd like to
+        shard the model instantly, which can save memory and
+        initialization time for extremely large models.
+
+        Returns: Model parallel context.
+        """
+        yield
+
+    @property
+    def call_configure_sharded_model_hook(self) -> bool:
+        """
+        Allow the model parallel hook to be called in suitable environments determined by the training type plugin.
+        This is useful when we want to shard the model once within fit.
+        Returns: True if we want to call the model parallel setup hook.
+ """ + return self._call_configure_sharded_model_hook + + @call_configure_sharded_model_hook.setter + def call_configure_sharded_model_hook(self, mode: bool) -> None: + self._call_configure_sharded_model_hook = mode diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 6d434e12a2e78..606f6b2e4b52b 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -38,6 +38,11 @@ def on_before_accelerator_backend_setup(self, model: LightningModule) -> None: for callback in self.callbacks: callback.on_before_accelerator_backend_setup(self, model) + def configure_sharded_model(self, model: LightningModule) -> None: + """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" + for callback in self.callbacks: + callback.on_configure_sharded_model(self, model) + def setup(self, model: LightningModule, stage: Optional[str]) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index 534dad5199e9b..87b730403b551 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -55,6 +55,11 @@ def _setup_log(): """Called when fit or test begins""" return None + @staticmethod + def _on_configure_sharded_model_log(): + """Called before configure sharded model""" + return None + @staticmethod def _teardown_log(): """Called at the end of fit and test""" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index fa02df7fb7ad1..78d9602f8e529 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -436,6 +436,7 @@ def fit( self.accelerator.connect(model) self.accelerator.setup_environment() self.call_setup_hook(model) # allow user to setup lightning_module in accelerator environment + self.call_configure_sharded_model(model) # allow user to setup in model sharded environment self.accelerator.setup(self, model) # note: this sets up self.lightning_module # ---------------------------- @@ -1082,6 +1083,15 @@ def call_setup_hook(self, model: LightningModule) -> None: self.setup(model, stage=state) model.setup(stage=state) + def call_configure_sharded_model(self, model: LightningModule) -> None: + # Call configure sharded model hook if accelerator requests. In some cases + # we will not call the hook; the hook has initialized the sharded model for example. + if self.accelerator.call_configure_sharded_model_hook: + with self.accelerator.model_sharded_context(): + model.configure_sharded_model() + self.configure_sharded_model(model) + self.accelerator.call_configure_sharded_model_hook = False + def call_teardown_hook(self, model: LightningModule) -> None: state = self._teardown_state diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index bd8636ba839f9..2ad151d75e76c 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -1,9 +1,24 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import pytest
 import torch
 
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
+from pytorch_lightning.plugins import SingleDevicePlugin
 from tests.accelerators.test_dp import CustomClassificationModelDP
+from tests.helpers.boring_model import BoringModel
 from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.runif import RunIf
 
@@ -44,3 +59,93 @@ def test_evaluate(tmpdir, trainer_kwargs):
     # make sure weights didn't change
     new_weights = model.layer_0.weight.clone().detach().cpu()
     torch.testing.assert_allclose(old_weights, new_weights)
+
+
+def test_model_parallel_setup_called(tmpdir):
+
+    class TestModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.configure_sharded_model_called = False
+            self.layer = None
+
+        def configure_sharded_model(self):
+            self.configure_sharded_model_called = True
+            self.layer = torch.nn.Linear(32, 2)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=1,
+    )
+    trainer.fit(model)
+
+    assert model.configure_sharded_model_called
+
+
+class DummyModel(BoringModel):
+
+    def __init__(self):
+        super().__init__()
+        self.configure_sharded_model_called = False
+
+    def configure_sharded_model(self):
+        self.configure_sharded_model_called = True
+
+
+def test_configure_sharded_model_false(tmpdir):
+    """Ensure ``configure_sharded_model`` is not called when turned off"""
+
+    class CustomPlugin(SingleDevicePlugin):
+
+        @property
+        def call_configure_sharded_model_hook(self) -> bool:
+            return False
+
+    model = DummyModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=1,
+        plugins=CustomPlugin(device=torch.device("cpu"))
+    )
+    trainer.fit(model)
+
+    assert not model.configure_sharded_model_called
+
+
+def test_accelerator_configure_sharded_model_called_once(tmpdir):
+    """Ensure the configure sharded model hook is called, then the flag is set to False so it is not called again."""
+
+    model = DummyModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=1,
+    )
+    assert trainer.accelerator.call_configure_sharded_model_hook is True
+    trainer.fit(model)
+    assert trainer.accelerator.call_configure_sharded_model_hook is False
+
+
+def test_configure_sharded_model_called_once(tmpdir):
+    """Ensure ``configure_sharded_model`` is only called once"""
+
+    model = DummyModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=1,
+    )
+    trainer.fit(model)
+
+    assert model.configure_sharded_model_called
+    model.configure_sharded_model_called = False
+
+    assert not model.configure_sharded_model_called
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index 713971629bdf4..a30b4fe0f609b 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -48,6 +48,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
         call.on_init_end(trainer),
         call.on_before_accelerator_backend_setup(trainer, model),
         call.setup(trainer, model, 'fit'),
+        call.on_configure_sharded_model(trainer, model),
         call.on_fit_start(trainer, model),
         call.on_pretrain_routine_start(trainer, model),
         call.on_pretrain_routine_end(trainer, model),
@@ -119,6 +120,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
         call.on_init_end(trainer),
         call.on_before_accelerator_backend_setup(trainer, model),
         call.setup(trainer, model, 'test'),
+        call.on_configure_sharded_model(trainer, model),
         call.on_test_start(trainer, model),
         call.on_epoch_start(trainer, model),
         call.on_test_epoch_start(trainer, model),
@@ -153,6 +155,7 @@ def test_trainer_callback_hook_system_validate(tmpdir):
         call.on_init_end(trainer),
         call.on_before_accelerator_backend_setup(trainer, model),
         call.setup(trainer, model, 'validate'),
+        call.on_configure_sharded_model(trainer, model),
         call.on_validation_start(trainer, model),
         call.on_epoch_start(trainer, model),
         call.on_validation_epoch_start(trainer, model),
diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 25103559cd070..c4ba371e6c561 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -281,6 +281,7 @@ def test_call_back_validator(tmpdir):
         'on_epoch_end',
         'on_epoch_start',
         'on_fit_end',
+        'on_configure_sharded_model',
         'on_fit_start',
         'on_init_end',
         'on_init_start',
@@ -317,6 +318,7 @@ def test_call_back_validator(tmpdir):
         "on_before_accelerator_backend_setup",
         "on_fit_end",
         "on_fit_start",
+        "on_configure_sharded_model",
         "on_init_end",
         "on_init_start",
         "on_keyboard_interrupt",
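
Usage sketch (not part of the diff): a minimal example of how a user-defined LightningModule might adopt the new `configure_sharded_model` hook, mirroring the BoringModel-based test added in tests/accelerators/test_common.py. The class name `MyLargeModel` and the layer sizes are illustrative assumptions, not part of this PR.

import torch
from pytorch_lightning import LightningModule


class MyLargeModel(LightningModule):  # hypothetical example module

    def __init__(self):
        super().__init__()
        # Defer layer creation: the Trainer calls `configure_sharded_model` inside the
        # training type plugin's `model_sharded_context()` during `fit`.
        self.layer = None

    def configure_sharded_model(self):
        # Instantiate (potentially very large) layers here so sharded plugins can
        # create them directly in a distributed aware context.
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

Per the trainer change above, `Trainer.fit` runs this hook once, then sets `accelerator.call_configure_sharded_model_hook` to False so the sharded parameters are not re-initialized in the same process.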