diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3e1ada8db632c..87ef042af77ee 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -63,7 +63,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed the deprecated `progress_bar_refresh_rate` argument from the `Trainer` constructor ([#12514](https://github.com/PyTorchLightning/pytorch-lightning/pull/12514))
--
+- Removed the deprecated `prepare_data_per_node` argument from the `Trainer` constructor ([#12536](https://github.com/PyTorchLightning/pytorch-lightning/pull/12536))
-
diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst
index 64aad4639f04d..40efa5fc8d7f8 100644
--- a/docs/source/common/trainer.rst
+++ b/docs/source/common/trainer.rst
@@ -1171,39 +1171,6 @@ To define your own behavior, subclass the relevant class and pass it in. Here's
trainer = Trainer(plugins=[MyCluster()], ...)
-
-prepare_data_per_node
-^^^^^^^^^^^^^^^^^^^^^
-.. warning:: ``prepare_data_per_node`` has been deprecated in v1.5 and will be removed in v1.7.
- Please set its value inside ``LightningDataModule`` and/or ``LightningModule`` directly described
- in the following code:
-
- .. testcode::
-
- class LitDataModule(LightningDataModule):
- def __init__(self):
- super().__init__()
- self.prepare_data_per_node = True
-
-.. raw:: html
-
-
-
-|
-
-If set to ``True`` will call ``prepare_data()`` on LOCAL_RANK=0 for every node.
-If set to ``False`` will only call from NODE_RANK=0, LOCAL_RANK=0.
-
-.. testcode::
-
- # default
- Trainer(prepare_data_per_node=True)
-
- # use only NODE_RANK=0, LOCAL_RANK=0
- Trainer(prepare_data_per_node=False)
-
precision
^^^^^^^^^
diff --git a/docs/source/starter/core_guide.rst b/docs/source/starter/core_guide.rst
index 42bb74290c639..20ec33534ac71 100644
--- a/docs/source/starter/core_guide.rst
+++ b/docs/source/starter/core_guide.rst
@@ -596,8 +596,7 @@ will cause all sorts of issues.
To solve this problem, make sure your download code is in the ``prepare_data`` method in the DataModule.
In this method we do all the preparation we need to do once (instead of on every GPU).
-``prepare_data`` can be called in two ways, once per node or only on the root node
-(``Trainer(prepare_data_per_node=False)``).
+``prepare_data`` can be called in two ways: once per node, or only on the root node.
.. code-block:: python
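
For context, the pattern the guide describes puts downloads in ``prepare_data`` and per-process state in ``setup``. A minimal sketch under that assumption (the module name, file paths, and tensor shapes below are illustrative, not taken from the guide's own example):

    import os

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from pytorch_lightning import LightningDataModule


    class MyDataModule(LightningDataModule):
        def __init__(self, data_dir: str = "./data"):
            super().__init__()
            self.data_dir = data_dir
            # True (default): LOCAL_RANK=0 of every node runs `prepare_data`
            # False: only NODE_RANK=0, LOCAL_RANK=0 runs it (e.g. a shared filesystem)
            self.prepare_data_per_node = True

        def prepare_data(self):
            # download / write to disk only; do not assign state here
            os.makedirs(self.data_dir, exist_ok=True)
            path = os.path.join(self.data_dir, "train.pt")
            if not os.path.exists(path):
                torch.save(torch.randn(64, 32), path)

        def setup(self, stage=None):
            # runs on every process, so assigning state here is safe
            self.train_data = TensorDataset(torch.load(os.path.join(self.data_dir, "train.pt")))

        def train_dataloader(self):
            return DataLoader(self.train_data, batch_size=8)
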
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 442da2274c360..319b41bd08e25 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -361,21 +361,27 @@ def prepare_data(self):
self.split = data_split
self.some_state = some_other_state()
- In DDP ``prepare_data`` can be called in two ways (using Trainer(prepare_data_per_node)):
+ In a distributed environment, ``prepare_data`` can be called in two ways
+ (using :ref:`prepare_data_per_node`):
1. Once per node. This is the default and is only called on LOCAL_RANK=0.
2. Once in total. Only called on GLOBAL_RANK=0.
- See :ref:`prepare_data_per_node`.
-
Example::
# DEFAULT
# called once per node on LOCAL_RANK=0 of that node
- Trainer(prepare_data_per_node=True)
+ class LitDataModule(LightningDataModule):
+ def __init__(self):
+ super().__init__()
+ self.prepare_data_per_node = True
+
# call on GLOBAL_RANK=0 (great for shared file systems)
- Trainer(prepare_data_per_node=False)
+ class LitDataModule(LightningDataModule):
+ def __init__(self):
+ super().__init__()
+ self.prepare_data_per_node = False
This is called before requesting the dataloaders:
@@ -387,6 +393,7 @@ def prepare_data(self):
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
+ model.predict_dataloader()
"""
def setup(self, stage: Optional[str] = None) -> None:
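
The old deprecation message pointed at both ``LightningDataModule`` and ``LightningModule``, while the docstring examples above only show the DataModule. A minimal sketch of the same attribute on a ``LightningModule`` (class name and dataset are illustrative only):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from pytorch_lightning import LightningModule


    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)  # placeholder model
            # prepare data once in total, on GLOBAL_RANK=0
            self.prepare_data_per_node = False

        def prepare_data(self):
            # download / tokenize / write to disk here; state assigned to `self`
            # in this hook is not available on the other processes
            ...

        def train_dataloader(self):
            # requested only after `prepare_data` (and `setup`) have run
            return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
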
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index 58f730f43eef1..ccf3710767e87 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -39,7 +39,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from pytorch_lightning.utilities.warnings import PossibleUserWarning, WarningCache
@@ -73,18 +73,9 @@ def on_trainer_init(
self,
check_val_every_n_epoch: int,
reload_dataloaders_every_n_epochs: int,
- prepare_data_per_node: Optional[bool] = None,
) -> None:
self.trainer.datamodule = None
- if prepare_data_per_node is not None:
- rank_zero_deprecation(
- "Setting `prepare_data_per_node` with the trainer flag is deprecated in v1.5.0 and will be removed in"
- " v1.7.0. Please set `prepare_data_per_node` in `LightningDataModule` and/or `LightningModule`"
- " directly instead."
- )
- self.trainer.prepare_data_per_node = prepare_data_per_node
-
if not isinstance(check_val_every_n_epoch, int):
raise MisconfigurationException(
f"check_val_every_n_epoch should be an integer. Found {check_val_every_n_epoch}"
@@ -112,28 +103,12 @@ def prepare_data(self) -> None:
# check for prepare_data_per_node & datamodule lifecycle properties before calling datamodule.prepare_data
if datamodule is not None:
dm_prepare_data_per_node = datamodule.prepare_data_per_node
- dm_eq_prepare_data = datamodule.prepare_data_per_node == self.trainer.prepare_data_per_node
- if self.trainer.prepare_data_per_node is not None and not dm_eq_prepare_data:
- raise MisconfigurationException(
- "Inconsistent settings found for `prepare_data_per_node`."
- f" Value was set with both `Trainer(prepare_data_per_node={self.trainer.prepare_data_per_node}.)`"
- f" and `DataModule.prepare_data_per_node={datamodule.prepare_data_per_node}`."
- " Move `prepare_data_per_node` setting to DataModule property."
- )
if (dm_prepare_data_per_node and local_rank_zero) or (not dm_prepare_data_per_node and global_rank_zero):
self.trainer.datamodule.prepare_data()
# handle lightning module prepare data:
# check for prepare_data_per_node before calling lightning_module.prepare_data
if lightning_module is not None:
lm_prepare_data_per_node = lightning_module.prepare_data_per_node
- lm_eq_prepare_data = lightning_module.prepare_data_per_node == self.trainer.prepare_data_per_node
- if (self.trainer.prepare_data_per_node is not None) and not lm_eq_prepare_data:
- raise MisconfigurationException(
- "Inconsistent settings found for `prepare_data_per_node`."
- f" Value was set with both `Trainer(prepare_data_per_node={self.trainer.prepare_data_per_node}.)`"
- f" and `LightningModule.prepare_data_per_node={lightning_module.prepare_data_per_node}`."
- " Move `prepare_data_per_node` setting to LightningModule property."
- )
if (lm_prepare_data_per_node and local_rank_zero) or (not lm_prepare_data_per_node and global_rank_zero):
self.trainer._call_lightning_module_hook("prepare_data")
self.trainer._is_data_prepared = True
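
A standalone sketch of the rank gating that remains in ``prepare_data`` above; the ``local_rank_zero`` / ``global_rank_zero`` definitions are assumed from the trainer's rank properties, which this hunk does not show:

    def runs_prepare_data(prepare_data_per_node: bool, local_rank: int, node_rank: int) -> bool:
        # mirrors `(prepare_data_per_node and local_rank_zero) or (not prepare_data_per_node and global_rank_zero)`
        local_rank_zero = local_rank == 0
        global_rank_zero = local_rank == 0 and node_rank == 0
        return (prepare_data_per_node and local_rank_zero) or (
            not prepare_data_per_node and global_rank_zero
        )


    # per-node preparation runs on LOCAL_RANK=0 of every node,
    # otherwise only on NODE_RANK=0, LOCAL_RANK=0
    assert runs_prepare_data(True, local_rank=0, node_rank=1)
    assert not runs_prepare_data(False, local_rank=0, node_rank=1)
    assert runs_prepare_data(False, local_rank=0, node_rank=0)
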
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 309e481e1045b..75d31325ca8fe 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -181,7 +181,6 @@ def __init__(
replace_sampler_ddp: bool = True,
detect_anomaly: bool = False,
auto_scale_batch_size: Union[str, bool] = False,
- prepare_data_per_node: Optional[bool] = None,
plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
amp_backend: str = "native",
amp_level: Optional[str] = None,
@@ -314,14 +313,6 @@ def __init__(
log_every_n_steps: How often to log within steps.
Default: ``50``.
- prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
- Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data
-
- .. deprecated:: v1.5
- Deprecated in v1.5.0 and will be removed in v1.7.0
- Please set ``prepare_data_per_node`` in ``LightningDataModule`` and/or
- ``LightningModule`` directly instead.
-
process_position: Orders the progress bar when running multiple models on same machine.
.. deprecated:: v1.5
@@ -542,7 +533,6 @@ def __init__(
self._data_connector.on_trainer_init(
check_val_every_n_epoch,
reload_dataloaders_every_n_epochs,
- prepare_data_per_node,
)
if terminate_on_nan is not None:
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index 4c337a44ca14e..9ffa443a809e2 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -487,14 +487,3 @@ class BoringDataModule2(LightningDataModule):
assert hasattr(BoringDataModule2, "__repr__")
assert BoringDataModule2(batch_size=32).prepare_data() is None
assert BoringDataModule2(batch_size=32) == BoringDataModule2(batch_size=32)
-
-
-def test_inconsistent_prepare_data_per_node(tmpdir):
- with pytest.raises(MisconfigurationException, match="Inconsistent settings found for `prepare_data_per_node`."):
- model = BoringModel()
- dm = BoringDataModule()
- with pytest.deprecated_call(match="prepare_data_per_node` with the trainer flag is deprecated"):
- trainer = Trainer(prepare_data_per_node=False)
- trainer.model = model
- trainer.datamodule = dm
- trainer._data_connector.prepare_data()
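
With the inconsistency check gone, coverage of the DataModule-level property could look roughly like the sketch below. It reuses the removed test's pattern of driving ``_data_connector.prepare_data()`` directly and the ``_is_data_prepared`` flag set in the connector, and it assumes the module's existing imports (``Trainer``, ``BoringModel``, ``BoringDataModule``); it is an illustration, not part of this change:

    def test_prepare_data_per_node_datamodule_property(tmpdir):
        # a single-process Trainer is both LOCAL_RANK=0 and GLOBAL_RANK=0
        model = BoringModel()
        dm = BoringDataModule()
        dm.prepare_data_per_node = False  # only GLOBAL_RANK=0 should prepare data
        trainer = Trainer(default_root_dir=tmpdir)
        trainer.model = model
        trainer.datamodule = dm
        trainer._data_connector.prepare_data()
        assert trainer._is_data_prepared
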
diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py
index ecd890e6b6291..fdc683edd23a2 100644
--- a/tests/deprecated_api/test_remove_1-7.py
+++ b/tests/deprecated_api/test_remove_1-7.py
@@ -125,11 +125,6 @@ def get_progress_bar_dict(self):
_ = trainer.progress_bar_dict
-def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
- with pytest.deprecated_call(match="Setting `prepare_data_per_node` with the trainer flag is deprecated in v1.5.0"):
- _ = Trainer(prepare_data_per_node=False)
-
-
@pytest.mark.parametrize("terminate_on_nan", [True, False])
def test_v1_7_0_trainer_terminate_on_nan(tmpdir, terminate_on_nan):
with pytest.deprecated_call(