diff --git a/.gitignore b/.gitignore index 390551b8f6e60c..cd0ba22453512d 100644 --- a/.gitignore +++ b/.gitignore @@ -155,3 +155,5 @@ cifar-10-batches-py # ctags tags data +MNIST +runs diff --git a/CHANGELOG.md b/CHANGELOG.md index ef33b0b06890a5..8ab18a66d37f5f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -108,6 +108,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `LightningDataModule.from_datasets(...)` ([#5133](https://github.com/PyTorchLightning/pytorch-lightning/pull/5133)) +- Added `PL_TORCH_DISTRIBUTED_BACKEND` env variable to select backend ([#5981](https://github.com/PyTorchLightning/pytorch-lightning/pull/5981)) + + ### Changed - Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839)) @@ -288,9 +291,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed passing wrong strings for scheduler interval doesn't throw an error ([#5923](https://github.com/PyTorchLightning/pytorch-lightning/pull/5923)) +- Fixed add `on_epoch_end` hook at the end of `validation`, `test` epoch ([#5986](https://github.com/PyTorchLightning/pytorch-lightning/pull/5986)) + + - Fixed missing `process_dataloader` call for `TPUSpawn` when in distributed mode ([#6015](https://github.com/PyTorchLightning/pytorch-lightning/pull/6015)) +- Fixed progress bar flickering by appending 0 to floats/strings ([#6009](https://github.com/PyTorchLightning/pytorch-lightning/pull/6009)) + + +- Fixed synchronization issues with TPU training ([#6027](https://github.com/PyTorchLightning/pytorch-lightning/pull/6027)) + + ## [1.1.8] - 2021-02-08 ### Fixed diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 19e552f3511354..69470dbff1a1da 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -239,6 +239,20 @@ Note in particular the difference between `gpus=0`, `gpus=[0]` and `gpus="0"`. to be in "exclusive mode", such that only one process at a time can access them. For more details see the :doc:`trainer guide <../common/trainer>`. + +Select torch distributed backend +-------------------------------- + +By default, Lightning will select the ``nccl`` backend over ``gloo`` when running on GPUs. +Find more information about PyTorch's supported backends `here `__. + +Lightning exposes an environment variable ``PL_TORCH_DISTRIBUTED_BACKEND`` for the user to change the backend. + +.. code-block:: bash + + PL_TORCH_DISTRIBUTED_BACKEND=gloo python train.py ... + + ---------- Distributed modes diff --git a/docs/source/benchmarking/performance.rst b/docs/source/benchmarking/performance.rst index 951c01943e071c..5f89c759e49bc7 100644 --- a/docs/source/benchmarking/performance.rst +++ b/docs/source/benchmarking/performance.rst @@ -135,8 +135,34 @@ Refer to the :doc:`distributed computing guide for more details <../advanced/mul Sequential Model Parallelism with Checkpointing ---------------------------------------------------------------------- +----------------------------------------------- PyTorch Lightning integration for Sequential Model Parallelism using `FairScale `_. Sequential Model Parallelism splits a sequential module onto multiple GPUs, reducing peak GPU memory requirements substantially. For more information, refer to :ref:`sequential-parallelism`. + + +Preload Data Into RAM +--------------------- + +When your training or preprocessing requires many operations to be performed on entire dataset(s) it can +sometimes be beneficial to store all data in RAM given there is enough space. +However, loading all data at the beginning of the training script has the disadvantage that it can take a long +time and hence it slows down the development process. Another downside is that in multiprocessing (e.g. DDP) +the data would get copied in each process. +One can overcome these problems by copying the data into RAM in advance. +Most UNIX-based operating systems provide direct access to tmpfs through a mount point typically named ``/dev/shm``. + +0. Increase shared memory if necessary. Refer to the documentation of your OS how to do this. + +1. Copy training data to shared memory: + + .. code-block:: bash + + cp -r /path/to/data/on/disk /dev/shm/ + +2. Refer to the new data root in your script or command line arguments: + + .. code-block:: python + + datamodule = MyDataModule(data_root="/dev/shm/my_data") diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index cfd525cf61f873..d568fd525b25fa 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union import torch from torch.optim import Optimizer +from torch.utils.data import DataLoader from pytorch_lightning.core import LightningModule from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin @@ -365,3 +366,11 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s A tensor of shape (world_size, batch, ...) """ return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + + def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: + """Wraps the dataloader if necessary + + Args: + dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader` + """ + return self.training_type_plugin.process_dataloader(dataloader) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index a8024bef2a5394..83d86b619c7c94 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -554,6 +554,14 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics): epoch = metrics.get("epoch") step = metrics.get("step") + # when `val_loss` is being logged and no ModelCheckpoint is being provided + # `val_loss` will be selected for monitor and need to be reduced to + # prevent processes divergence + # TODO: Move this logic to logger_connector. This also needs to be fixed for any + # other monitor logged value which aren't produced from a Metric. + if self.monitor == "val_loss": + current = trainer.training_type_plugin.reduce(current, reduce_op="mean") + if self.check_monitor_top_k(current): self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics) elif self.verbose: diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 7de7982b4a2de3..3f401669c351e5 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -26,12 +26,37 @@ from typing import Optional, Union if importlib.util.find_spec('ipywidgets') is not None: - from tqdm.auto import tqdm + from tqdm.auto import tqdm as _tqdm else: - from tqdm import tqdm + from tqdm import tqdm as _tqdm from pytorch_lightning.callbacks import Callback +_PAD_SIZE = 5 + + +class tqdm(_tqdm): + """ + Custom tqdm progressbar where we append 0 to floating points/strings to + prevent the progress bar from flickering + """ + + @staticmethod + def format_num(n) -> str: + """ Add additional padding to the formatted numbers """ + should_be_padded = isinstance(n, (float, str)) + if not isinstance(n, str): + n = _tqdm.format_num(n) + if should_be_padded and 'e' not in n: + if '.' not in n and len(n) < _PAD_SIZE: + try: + _ = float(n) + except ValueError: + return n + n += '.' + n += "0" * (_PAD_SIZE - len(n)) + return n + class ProgressBarBase(Callback): r""" diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 906c7ab025d898..f018b47533f4a1 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -210,11 +210,10 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None: os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address()) os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size()) - torch_backend = "nccl" if self.on_gpu else "gloo" if not torch.distributed.is_initialized(): log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") - torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size) + torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size) def pre_dispatch(self): # TODO: check if needed diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 449373e2c35eab..1265ba945f1622 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -194,11 +194,10 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None: os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address()) os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size()) - torch_backend = "nccl" if self.on_gpu else "gloo" if not torch.distributed.is_initialized(): log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") - torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size) + torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size) def determine_ddp_device_ids(self): if self.root_device.type == "cpu": diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index 3d6d661f171ee0..c27c9705cd3629 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import io +import os from abc import ABC, abstractmethod from contextlib import contextmanager from typing import List, Optional @@ -23,6 +24,7 @@ from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin +from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp @@ -82,6 +84,13 @@ def reduce_early_stopping_decision(self, should_stop: bool) -> bool: should_stop = bool(should_stop == self.world_size) return should_stop + @property + def torch_distributed_backend(self): + torch_backend = os.getenv("PL_TORCH_DISTRIBUTED_BACKEND") + if torch_backend is None: + torch_backend = "nccl" if self.on_gpu else "gloo" + return torch_backend + @staticmethod def configure_sync_batchnorm(model: LightningModule) -> LightningModule: """ diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 5aed1bfaa10e8b..83a441aabad308 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -10,7 +10,8 @@ from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn -from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import seed_everything if _TPU_AVAILABLE: @@ -46,10 +47,6 @@ def create_mp_queue(self): def distributed_sampler_kwargs(self) -> dict: return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) - @property - def should_finalize(self): - return self.world_size == 1 - @property def is_distributed(self): return self.world_size != 1 @@ -179,6 +176,24 @@ def reduce_early_stopping_decision(self, should_stop: bool) -> bool: should_stop = int(stop.item()) == self.world_size return should_stop + def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): + if not isinstance(output, torch.Tensor): + output = torch.tensor(output, device=self.device) + + _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM + _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg") + if _invalid_reduce_op or _invalid_reduce_op_str: + raise MisconfigurationException( + "Currently, TPUSpawn TrainingTypePlugin only support `sum`, `mean`, `avg` reduce operation." + ) + + output = xm.mesh_reduce('reduce', output, sum) + + if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"): + output = output / self.world_size + + return output + def post_dispatch(self) -> None: # TODO: Check if trainer references can be resolved otherwise model = self.lightning_module @@ -213,6 +228,10 @@ def __load_weights_on_main_process(self) -> None: self._model = model + def _close_logger(self, trainer) -> None: + if hasattr(trainer, "logger"): + trainer.logger.finalize("success") + @property def xmp_spawn_kwargs(self): return { @@ -225,9 +244,11 @@ def start_training(self, trainer) -> None: # todo: precision pluging is call in accelerator setup and should be moved if 'XLA_USE_BF16' in os.environ: del os.environ["XLA_USE_BF16"] + self._close_logger(trainer) xmp.spawn(self.new_process, **self.xmp_spawn_kwargs) def start_testing(self, trainer) -> None: + self._close_logger(trainer) xmp.spawn(self.new_process, **self.xmp_spawn_kwargs) def start_predicting(self, trainer) -> None: diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index cede3e5f98b435..938a17249e9f65 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -35,10 +35,6 @@ def __init__(self) -> None: self._results = None self.global_rank = 0 - @property - def should_finalize(self): - return True - @property @abstractmethod def on_gpu(self) -> bool: diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py similarity index 99% rename from pytorch_lightning/accelerators/accelerator_connector.py rename to pytorch_lightning/trainer/connectors/accelerator_connector.py index 644b382b6bba21..2c4eafb6ed0e85 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -60,7 +60,7 @@ import horovod.torch as hvd -class BackendConnector(object): +class AcceleratorConnector(object): def __init__( self, diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 595a5e84bf6305..89c72883fc497b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -14,7 +14,7 @@ import os from copy import deepcopy from pprint import pprint -from typing import Dict, Iterable, Union +from typing import Dict, Iterable, Optional, Union import torch @@ -32,7 +32,7 @@ class LoggerConnector: - def __init__(self, trainer, log_gpu_memory: bool): + def __init__(self, trainer, log_gpu_memory: Optional[str] = None): self.trainer = trainer self.log_gpu_memory = log_gpu_memory self._callback_metrics = MetricsHolder() diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index a6aeeb7d73f786..71b557bf75a2c0 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.accelerators.accelerator_connector import BackendConnector +from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn @@ -22,7 +22,7 @@ class DeprecatedDistDeviceAttributes: _device_type: DeviceType _running_stage: RunningStage num_gpus: int - accelerator_connector: BackendConnector + accelerator_connector: AcceleratorConnector @property def on_cpu(self) -> bool: diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index fe3fc62ff11897..053c4ea5ae3603 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -318,6 +318,8 @@ def on_evaluation_epoch_end(self, *args, **kwargs): else: self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs) + self.trainer.call_hook('on_epoch_end') + def log_evaluation_step_metrics(self, output, batch_idx): if self.trainer.running_sanity_check: return diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 1f0cc52870f7e7..282b4539df0bee 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -21,7 +21,7 @@ from torch.optim import Optimizer from pytorch_lightning.accelerators import Accelerator -from pytorch_lightning.accelerators.accelerator_connector import BackendConnector +from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.lightning import LightningModule @@ -51,7 +51,7 @@ class TrainerProperties(ABC): _state: TrainerState _weights_save_path: str - accelerator_connector: BackendConnector + accelerator_connector: AcceleratorConnector callbacks: List[Callback] checkpoint_connector: CheckpointConnector limit_val_batches: int @@ -373,6 +373,11 @@ def optimizers(self) -> Optional[List[Optimizer]]: @optimizers.setter def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None: + # Necessary to rewrap optimizers to lightning + # They will be re-created when accessing + # the `lightning_optimizers` trainer property + self._lightning_optimizers = None + self.accelerator.optimizers = new_optims @property diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2453a08ba9067b..46ca290b24d349 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -22,7 +22,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.accelerators import Accelerator -from pytorch_lightning.accelerators.accelerator_connector import BackendConnector +from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule @@ -308,7 +308,7 @@ def __init__( self.data_connector = DataConnector(self) self.optimizer_connector = OptimizerConnector(self) - self.accelerator_connector = BackendConnector( + self.accelerator_connector = AcceleratorConnector( num_processes, tpu_cores, distributed_backend, auto_select_gpus, gpus, num_nodes, sync_batchnorm, benchmark, replace_sampler_ddp, deterministic, precision, amp_backend, amp_level, plugins ) @@ -711,7 +711,7 @@ def run_evaluation(self, max_batches=None, on_epoch=False): for dataloader_idx, dataloader in enumerate(dataloaders): # bookkeeping dl_outputs = [] - dataloader = self.training_type_plugin.process_dataloader(dataloader) + dataloader = self.accelerator.process_dataloader(dataloader) dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx] for batch_idx, batch in enumerate(dataloader): @@ -823,7 +823,7 @@ def run_predict(self): # run validation/testing for dataloader_idx, dataloader in enumerate(dataloaders): - dataloader = self.training_type_plugin.process_dataloader(dataloader) + dataloader = self.accelerator.process_dataloader(dataloader) dl_max_batches = self.predict_loop.max_batches[dataloader_idx] for batch_idx, batch in enumerate(dataloader): diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 0908e96bd1c179..57c0b10f124124 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -140,7 +140,7 @@ def on_train_end(self): # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers. # It might be related to xla tensors blocked when moving the cpu # kill loggers - if self.trainer.logger is not None and self.trainer.training_type_plugin.should_finalize: + if self.trainer.logger is not None: self.trainer.logger.finalize("success") # summarize profile results @@ -502,7 +502,7 @@ def tbptt_split_batch(self, batch): def run_training_epoch(self): # modify dataloader if needed (ddp, etc...) - train_dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader) + train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) # track epoch output epoch_output = [[] for _ in range(self.num_optimizers)] diff --git a/tests/accelerators/test_ddp.py b/tests/accelerators/test_ddp.py index 1bd9e36d4f2053..b582532cd710ea 100644 --- a/tests/accelerators/test_ddp.py +++ b/tests/accelerators/test_ddp.py @@ -12,11 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import platform +from unittest.mock import patch import pytest import torch +from pytorch_lightning import Trainer from tests.accelerators import ddp_model, DDPLauncher +from tests.helpers.boring_model import BoringModel from tests.utilities.distributed import call_training_script @@ -83,3 +87,25 @@ def test_cli_to_pass(tmpdir, args=None): This test verify we can call function using test_cli name """ return '1' + + +@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(torch.cuda.is_available(), reason="test doesn't requires GPU machine") +def test_torch_distributed_backend_env_variables(tmpdir): + """ + This test set `undefined` as torch backend and should raise an `Backend.UNDEFINED` ValueError. + """ + _environ = {"PL_TORCH_DISTRIBUTED_BACKEND": "undefined", "CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2"} + with patch.dict(os.environ, _environ), \ + patch('torch.cuda.device_count', return_value=2): + + with pytest.raises(ValueError, match="Invalid backend: 'undefined'"): + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + accelerator="ddp", + gpus=2, + logger=False, + ) + trainer.fit(model) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 8bb6d3c8dc8154..379bc79263a6ec 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -65,6 +65,7 @@ def test_trainer_callback_system(torch_save, tmpdir): call.on_validation_batch_start(trainer, model, ANY, 0, 0), call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), call.on_validation_epoch_end(trainer, model), + call.on_epoch_end(trainer, model), call.on_validation_end(trainer, model), call.on_sanity_check_end(trainer, model), call.on_train_start(trainer, model), @@ -95,6 +96,7 @@ def test_trainer_callback_system(torch_save, tmpdir): call.on_validation_batch_start(trainer, model, ANY, 0, 0), call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), call.on_validation_epoch_end(trainer, model), + call.on_epoch_end(trainer, model), call.on_validation_end(trainer, model), call.on_save_checkpoint(trainer, model), call.on_train_end(trainer, model), @@ -119,6 +121,7 @@ def test_trainer_callback_system(torch_save, tmpdir): call.on_test_batch_start(trainer, model, ANY, 1, 0), call.on_test_batch_end(trainer, model, ANY, ANY, 1, 0), call.on_test_epoch_end(trainer, model), + call.on_epoch_end(trainer, model), call.on_test_end(trainer, model), call.on_fit_end(trainer, model), call.teardown(trainer, model, 'fit'), diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 8398aec88fe68e..9ec48008512fb8 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -20,6 +20,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase +from pytorch_lightning.callbacks.progress import tqdm from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel @@ -371,3 +372,12 @@ def training_step(self, batch, batch_idx): pbar = trainer.progress_bar_callback.main_progress_bar actual = str(pbar.postfix) assert actual.endswith("foo=0.123, bar={'baz': tensor([1])}") + + +@pytest.mark.parametrize( + "input_num, expected", [[1, '1'], [1.0, '1.000'], [0.1, '0.100'], [1e-3, '0.001'], [1e-5, '1e-5'], ['1.0', '1.000'], + ['10000', '10000'], ['abc', 'abc']] +) +def test_tqdm_format_num(input_num, expected): + """ Check that the specialized tqdm.format_num appends 0 to floats and strings """ + assert tqdm.format_num(input_num) == expected diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 057512be31af27..4067274165fe9c 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -420,6 +420,7 @@ def teardown(self, stage: str): 'on_validation_batch_start', 'on_validation_batch_end', 'on_validation_epoch_end', + 'on_epoch_end', 'on_validation_end', 'on_validation_model_train', 'on_train_start', @@ -441,6 +442,7 @@ def teardown(self, stage: str): 'on_validation_batch_start', 'on_validation_batch_end', 'on_validation_epoch_end', + 'on_epoch_end', 'on_save_checkpoint', 'on_validation_end', 'on_validation_model_train', @@ -462,6 +464,7 @@ def teardown(self, stage: str): 'on_test_batch_start', 'on_test_batch_end', 'on_test_epoch_end', + 'on_epoch_end', 'on_test_end', 'on_test_model_train', 'on_fit_end', diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index d9ea8a9917d2b0..4c6620b07b74a3 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -26,6 +26,7 @@ from pytorch_lightning.plugins import TPUSpawnPlugin from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _TPU_AVAILABLE +from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset from tests.helpers.utils import pl_multi_process_test @@ -264,9 +265,6 @@ def test_distributed_backend_set_when_using_tpu(tmpdir, tpu_cores): @pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") -@pytest.mark.skipif( - not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" -) @pl_multi_process_test def test_broadcast_on_tpu(): """ Checks if an object from the master process is broadcasted to other processes correctly""" @@ -327,3 +325,26 @@ def test_tpu_cores_with_argparse(cli_args, expected): for k, v in expected.items(): assert getattr(args, k) == v assert Trainer.from_argparse_args(args) + + +@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") +@pl_multi_process_test +def test_tpu_reduce(): + """Test tpu spawn reduce operation """ + + def test_reduce(rank): + trainer = Trainer(tpu_cores=8) + # faster this way + reduce_ops = ["mean", "AVG", "undefined", "sum", ReduceOp.SUM, ReduceOp.MAX] + for reduce_op in reduce_ops: + if reduce_op == "undefined" or reduce_op == ReduceOp.MAX: + with pytest.raises(MisconfigurationException, match="TPUSpawn TrainingTypePlugin only support"): + result = trainer.training_type_plugin.reduce(1, reduce_op) + else: + result = trainer.training_type_plugin.reduce(1, reduce_op) + if isinstance(reduce_op, str) and reduce_op.lower() in ("mean", "avg"): + assert result.item() == 1 + else: + assert result.item() == 8 + + xmp.spawn(test_reduce, nprocs=8, start_method='fork') diff --git a/tests/trainer/logging_/test_eval_loop_logging_1_0.py b/tests/trainer/logging_/test_eval_loop_logging_1_0.py index a635a1dfe35370..46890f68017111 100644 --- a/tests/trainer/logging_/test_eval_loop_logging_1_0.py +++ b/tests/trainer/logging_/test_eval_loop_logging_1_0.py @@ -539,9 +539,10 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, self.count += 1 def on_epoch_end(self, trainer, pl_module): - self.make_logging( - pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices - ) + if not trainer.training: + self.make_logging( + pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices + ) def on_validation_epoch_end(self, trainer, pl_module): self.make_logging( @@ -584,8 +585,8 @@ def validation_step(self, batch, batch_idx): assert test_callback.funcs_called_count["on_validation_start"] == 1 assert test_callback.funcs_called_count["on_validation_epoch_start"] == 1 # assert test_callback.funcs_called_count["on_validation_batch_start"] == 4 - assert test_callback.funcs_called_count["on_validation_batch_end"] == 4 assert test_callback.funcs_called_count["on_epoch_end"] == 1 + assert test_callback.funcs_called_count["on_validation_batch_end"] == 4 assert test_callback.funcs_called_count["on_validation_epoch_end"] == 1 # Make sure the func_name exists within callback_metrics. If not, we missed some @@ -705,11 +706,6 @@ def on_test_start(self, trainer, pl_module): pl_module, 'on_test_start', 1, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices ) - def on_epoch_start(self, trainer, pl_module): - self.make_logging( - pl_module, 'on_epoch_start', 2, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices - ) - def on_test_epoch_start(self, trainer, pl_module): self.make_logging( pl_module, @@ -735,11 +731,6 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal # with func = np.mean if on_epoch else func = np.max self.count += 1 - def on_epoch_end(self, trainer, pl_module): - self.make_logging( - pl_module, 'on_epoch_end', 6, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices - ) - def on_test_epoch_end(self, trainer, pl_module): self.make_logging( pl_module, 'on_test_epoch_end', 7, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices @@ -775,11 +766,9 @@ def test_dataloader(self): max_epochs=max_epochs, callbacks=[test_callback], ) - trainer.fit(model) - trainer.test() + trainer.test(model) assert test_callback.funcs_called_count["on_test_start"] == 1 - assert test_callback.funcs_called_count["on_epoch_start"] == 2 assert test_callback.funcs_called_count["on_test_epoch_start"] == 1 assert test_callback.funcs_called_count["on_test_batch_end"] == 4 assert test_callback.funcs_called_count["on_test_epoch_end"] == 1 diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 71caaaad4d7f97..ce2eeb43e01149 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -540,12 +540,12 @@ def test_resume_from_checkpoint_epoch_restored(monkeypatch, tmpdir, tmpdir_serve class TestModel(BoringModel): # Model that tracks epochs and batches seen - num_epochs_seen = 0 + num_epochs_end_seen = 0 num_batches_seen = 0 num_on_load_checkpoint_called = 0 def on_epoch_end(self): - self.num_epochs_seen += 1 + self.num_epochs_end_seen += 1 def on_train_batch_start(self, *_): self.num_batches_seen += 1 @@ -567,7 +567,8 @@ def on_load_checkpoint(self, _): ) trainer.fit(model) - assert model.num_epochs_seen == 2 + # `on_epoch_end` will be called once for val_sanity, twice for train, twice for val + assert model.num_epochs_end_seen == 1 + 2 + 2 assert model.num_batches_seen == trainer.num_training_batches * 2 assert model.num_on_load_checkpoint_called == 0