Commit

Merge branch 'master' into refactor/flake8-plugins
tchaton authored Feb 17, 2021
2 parents 0fc0e8b + 15d6788 commit c5e5785
Showing 26 changed files with 233 additions and 53 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -155,3 +155,5 @@ cifar-10-batches-py
# ctags
tags
data
MNIST
runs
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -108,6 +108,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `LightningDataModule.from_datasets(...)` ([#5133](https://github.com/PyTorchLightning/pytorch-lightning/pull/5133))


- Added `PL_TORCH_DISTRIBUTED_BACKEND` env variable to select backend ([#5981](https://github.com/PyTorchLightning/pytorch-lightning/pull/5981))


### Changed

- Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
@@ -288,9 +291,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed passing wrong strings for scheduler interval not throwing an error ([#5923](https://github.com/PyTorchLightning/pytorch-lightning/pull/5923))


- Fixed missing `on_epoch_end` hook call at the end of the `validation` and `test` epochs ([#5986](https://github.com/PyTorchLightning/pytorch-lightning/pull/5986))


- Fixed missing `process_dataloader` call for `TPUSpawn` when in distributed mode ([#6015](https://github.com/PyTorchLightning/pytorch-lightning/pull/6015))


- Fixed progress bar flickering by appending 0 to floats/strings ([#6009](https://github.com/PyTorchLightning/pytorch-lightning/pull/6009))


- Fixed synchronization issues with TPU training ([#6027](https://github.com/PyTorchLightning/pytorch-lightning/pull/6027))


## [1.1.8] - 2021-02-08

### Fixed
14 changes: 14 additions & 0 deletions docs/source/advanced/multi_gpu.rst
@@ -239,6 +239,20 @@ Note in particular the difference between `gpus=0`, `gpus=[0]` and `gpus="0"`.
to be in "exclusive mode", such that only one process at a time can access them.
For more details see the :doc:`trainer guide <../common/trainer>`.


Select torch distributed backend
--------------------------------

By default, Lightning will select the ``nccl`` backend over ``gloo`` when running on GPUs.
Find more information about PyTorch's supported backends `here <https://pytorch.org/docs/stable/distributed.html>`__.

Lightning exposes an environment variable ``PL_TORCH_DISTRIBUTED_BACKEND`` for the user to change the backend.

.. code-block:: bash

    PL_TORCH_DISTRIBUTED_BACKEND=gloo python train.py ...
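
For illustration, here is a simplified Python sketch of the fallback this variable controls. It mirrors the ``torch_distributed_backend`` property added to ``pytorch_lightning/plugins/training_type/parallel.py`` in this commit; the standalone ``pick_backend`` helper exists only for this example.

.. code-block:: python

    import os


    def pick_backend(on_gpu: bool) -> str:
        # an explicit PL_TORCH_DISTRIBUTED_BACKEND export takes precedence
        backend = os.getenv("PL_TORCH_DISTRIBUTED_BACKEND")
        if backend is None:
            # otherwise default to nccl on GPU and gloo on CPU
            backend = "nccl" if on_gpu else "gloo"
        return backend
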
----------

Distributed modes
28 changes: 27 additions & 1 deletion docs/source/benchmarking/performance.rst
@@ -135,8 +135,34 @@ Refer to the :doc:`distributed computing guide for more details <../advanced/mul


Sequential Model Parallelism with Checkpointing
---------------------------------------------------------------------
-----------------------------------------------
PyTorch Lightning integration for Sequential Model Parallelism using `FairScale <https://github.com/facebookresearch/fairscale>`_.
Sequential Model Parallelism splits a sequential module onto multiple GPUs, reducing peak GPU memory requirements substantially.

For more information, refer to :ref:`sequential-parallelism`.


Preload Data Into RAM
---------------------

When your training or preprocessing requires many operations to be performed on entire dataset(s), it can
sometimes be beneficial to store all data in RAM, given there is enough space.
However, loading all data at the beginning of the training script has the disadvantage that it can take a long
time, which slows down the development process. Another downside is that in multiprocessing (e.g. DDP)
the data would get copied in each process.
One can overcome these problems by copying the data into RAM in advance, as described in the steps below.
Most UNIX-based operating systems provide direct access to tmpfs through a mount point typically named ``/dev/shm``.

0. Increase shared memory if necessary. Refer to the documentation of your OS for how to do this.

1. Copy training data to shared memory:

    .. code-block:: bash

        cp -r /path/to/data/on/disk /dev/shm/

2. Refer to the new data root in your script or command line arguments:

    .. code-block:: python

        datamodule = MyDataModule(data_root="/dev/shm/my_data")
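
As a hypothetical end-to-end illustration (``MyDataModule`` and its ``data_root`` argument come from the snippet above and are not part of the Lightning API; ``torchvision`` and a pre-copied MNIST folder under ``/dev/shm`` are assumed), a datamodule reading from the shared-memory copy could look like:

.. code-block:: python

    import pytorch_lightning as pl
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms


    class MyDataModule(pl.LightningDataModule):

        def __init__(self, data_root: str = "/dev/shm/my_data", batch_size: int = 64):
            super().__init__()
            self.data_root = data_root
            self.batch_size = batch_size

        def train_dataloader(self):
            # the dataset was copied to tmpfs beforehand, so reads never hit the disk
            dataset = datasets.MNIST(self.data_root, train=True, transform=transforms.ToTensor())
            return DataLoader(dataset, batch_size=self.batch_size)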
11 changes: 10 additions & 1 deletion pytorch_lightning/accelerators/accelerator.py
@@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union

import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
@@ -365,3 +366,11 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s
            A tensor of shape (world_size, batch, ...)
        """
        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

    def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Wraps the dataloader if necessary
        Args:
            dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
        """
        return self.training_type_plugin.process_dataloader(dataloader)
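
# Illustration only, not part of this commit: training type plugins can override
# `process_dataloader` to wrap the loader before iteration. A hypothetical override might
# look like the following (`DeviceAwareLoader` is invented for this sketch, not a real API):
#
#     class MyPlugin(TrainingTypePlugin):
#         def process_dataloader(self, dataloader: DataLoader) -> Iterable:
#             return DeviceAwareLoader(dataloader, device=self.root_device)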
8 changes: 8 additions & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -554,6 +554,14 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
        epoch = metrics.get("epoch")
        step = metrics.get("step")

        # when `val_loss` is being logged and no ModelCheckpoint is being provided
        # `val_loss` will be selected for monitor and need to be reduced to
        # prevent processes divergence
        # TODO: Move this logic to logger_connector. This also needs to be fixed for any
        # other monitor logged value which aren't produced from a Metric.
        if self.monitor == "val_loss":
            current = trainer.training_type_plugin.reduce(current, reduce_op="mean")

        if self.check_monitor_top_k(current):
            self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
        elif self.verbose:
29 changes: 27 additions & 2 deletions pytorch_lightning/callbacks/progress.py
@@ -26,12 +26,37 @@
from typing import Optional, Union

if importlib.util.find_spec('ipywidgets') is not None:
    from tqdm.auto import tqdm
    from tqdm.auto import tqdm as _tqdm
else:
    from tqdm import tqdm
    from tqdm import tqdm as _tqdm

from pytorch_lightning.callbacks import Callback

_PAD_SIZE = 5


class tqdm(_tqdm):
    """
    Custom tqdm progressbar where we append 0 to floating points/strings to
    prevent the progress bar from flickering
    """

    @staticmethod
    def format_num(n) -> str:
        """ Add additional padding to the formatted numbers """
        should_be_padded = isinstance(n, (float, str))
        if not isinstance(n, str):
            n = _tqdm.format_num(n)
        if should_be_padded and 'e' not in n:
            if '.' not in n and len(n) < _PAD_SIZE:
                try:
                    _ = float(n)
                except ValueError:
                    return n
                n += '.'
            n += "0" * (_PAD_SIZE - len(n))
        return n
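
# Illustration (values worked out from `format_num` above, not output shipped with this commit):
# with _PAD_SIZE = 5, floats and numeric strings are zero-padded to a stable width so the
# progress bar text stops changing length between refreshes:
#
#     tqdm.format_num(0.1)     -> '0.100'
#     tqdm.format_num('1.5')   -> '1.500'
#     tqdm.format_num(2)       -> '2'       # ints are left untouched
#     tqdm.format_num('1e-4')  -> '1e-4'    # scientific notation is never padded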


class ProgressBarBase(Callback):
r"""
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -210,11 +210,10 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())
torch_backend = "nccl" if self.on_gpu else "gloo"

if not torch.distributed.is_initialized():
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)
torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)

def pre_dispatch(self):
# TODO: check if needed
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -194,11 +194,10 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())
torch_backend = "nccl" if self.on_gpu else "gloo"

if not torch.distributed.is_initialized():
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)
torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)

def determine_ddp_device_ids(self):
if self.root_device.type == "cpu":
9 changes: 9 additions & 0 deletions pytorch_lightning/plugins/training_type/parallel.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import List, Optional
@@ -23,6 +24,7 @@
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp


@@ -82,6 +84,13 @@ def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
        should_stop = bool(should_stop == self.world_size)
        return should_stop

    @property
    def torch_distributed_backend(self):
        torch_backend = os.getenv("PL_TORCH_DISTRIBUTED_BACKEND")
        if torch_backend is None:
            torch_backend = "nccl" if self.on_gpu else "gloo"
        return torch_backend

    @staticmethod
    def configure_sync_batchnorm(model: LightningModule) -> LightningModule:
        """
31 changes: 26 additions & 5 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -10,7 +10,8 @@
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything

if _TPU_AVAILABLE:
@@ -46,10 +47,6 @@ def create_mp_queue(self):
    def distributed_sampler_kwargs(self) -> dict:
        return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())

    @property
    def should_finalize(self):
        return self.world_size == 1

    @property
    def is_distributed(self):
        return self.world_size != 1
@@ -179,6 +176,24 @@ def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
        should_stop = int(stop.item()) == self.world_size
        return should_stop

    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
        if not isinstance(output, torch.Tensor):
            output = torch.tensor(output, device=self.device)

        _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
        _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
        if _invalid_reduce_op or _invalid_reduce_op_str:
            raise MisconfigurationException(
                "Currently, TPUSpawn TrainingTypePlugin only support `sum`, `mean`, `avg` reduce operation."
            )

        output = xm.mesh_reduce('reduce', output, sum)

        if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
            output = output / self.world_size

        return output

    def post_dispatch(self) -> None:
        # TODO: Check if trainer references can be resolved otherwise
        model = self.lightning_module
@@ -213,6 +228,10 @@ def __load_weights_on_main_process(self) -> None:

        self._model = model

    def _close_logger(self, trainer) -> None:
        if hasattr(trainer, "logger"):
            trainer.logger.finalize("success")

    @property
    def xmp_spawn_kwargs(self):
        return {
@@ -225,9 +244,11 @@ def start_training(self, trainer) -> None:
        # todo: precision plugin is called in accelerator setup and should be moved
        if 'XLA_USE_BF16' in os.environ:
            del os.environ["XLA_USE_BF16"]
        self._close_logger(trainer)
        xmp.spawn(self.new_process, **self.xmp_spawn_kwargs)

    def start_testing(self, trainer) -> None:
        self._close_logger(trainer)
        xmp.spawn(self.new_process, **self.xmp_spawn_kwargs)

    def start_predicting(self, trainer) -> None:
@@ -35,10 +35,6 @@ def __init__(self) -> None:
        self._results = None
        self.global_rank = 0

    @property
    def should_finalize(self):
        return True

    @property
    @abstractmethod
    def on_gpu(self) -> bool:
@@ -60,7 +60,7 @@
    import horovod.torch as hvd


class BackendConnector(object):
class AcceleratorConnector(object):

    def __init__(
        self,
@@ -14,7 +14,7 @@
import os
from copy import deepcopy
from pprint import pprint
from typing import Dict, Iterable, Union
from typing import Dict, Iterable, Optional, Union

import torch

@@ -32,7 +32,7 @@

class LoggerConnector:

    def __init__(self, trainer, log_gpu_memory: bool):
    def __init__(self, trainer, log_gpu_memory: Optional[str] = None):
        self.trainer = trainer
        self.log_gpu_memory = log_gpu_memory
        self._callback_metrics = MetricsHolder()
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/deprecated_api.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.accelerators.accelerator_connector import BackendConnector
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn

@@ -22,7 +22,7 @@ class DeprecatedDistDeviceAttributes:
    _device_type: DeviceType
    _running_stage: RunningStage
    num_gpus: int
    accelerator_connector: BackendConnector
    accelerator_connector: AcceleratorConnector

    @property
    def on_cpu(self) -> bool:
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -318,6 +318,8 @@ def on_evaluation_epoch_end(self, *args, **kwargs):
        else:
            self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)

        self.trainer.call_hook('on_epoch_end')

    def log_evaluation_step_metrics(self, output, batch_idx):
        if self.trainer.running_sanity_check:
            return
9 changes: 7 additions & 2 deletions pytorch_lightning/trainer/properties.py
@@ -21,7 +21,7 @@
from torch.optim import Optimizer

from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.accelerators.accelerator_connector import BackendConnector
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
@@ -51,7 +51,7 @@ class TrainerProperties(ABC):
    _state: TrainerState
    _weights_save_path: str

    accelerator_connector: BackendConnector
    accelerator_connector: AcceleratorConnector
    callbacks: List[Callback]
    checkpoint_connector: CheckpointConnector
    limit_val_batches: int
@@ -373,6 +373,11 @@ def optimizers(self) -> Optional[List[Optimizer]]:

    @optimizers.setter
    def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None:
        # Necessary to rewrap optimizers to lightning
        # They will be re-created when accessing
        # the `lightning_optimizers` trainer property
        self._lightning_optimizers = None

        self.accelerator.optimizers = new_optims

    @property