Refactor: clean trainer device & distrib getters #5300

Merged: 17 commits, Jan 12, 2021
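
This PR replaces the Trainer's boolean getters (use_ddp, use_ddp2, use_dp, use_horovod, use_single_gpu, use_tpu, on_gpu, on_tpu) with two internal attributes, trainer._distrib_type and trainer._device_type, which are compared against the DistributedType and DeviceType enums from pytorch_lightning.utilities. The snippet below is only a sketch of how such str-backed enums could be declared and queried; the member names are the ones used in this diff, but the real definitions live in pytorch_lightning.utilities and may differ in detail.

from enum import Enum

class DeviceType(str, Enum):       # sketch only; see pytorch_lightning.utilities for the real enum
    CPU = 'cpu'
    GPU = 'gpu'
    TPU = 'tpu'

class DistributedType(str, Enum):  # sketch only; member names taken from this diff
    DP = 'dp'
    DDP = 'ddp'
    DDP2 = 'ddp2'
    DDP_SPAWN = 'ddp_spawn'
    HOROVOD = 'horovod'

# Call sites in the diff compare the trainer attributes against these members, e.g.:
#   trainer._device_type == DeviceType.GPU                                     # replaces trainer.on_gpu
#   trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)  # replaces trainer.use_ddp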
37 changes: 24 additions & 13 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -185,14 +185,21 @@ def select_accelerator(self):
# ----------------------------------
# choose an accelerator for the user
# ----------------------------------
- use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks
+ use_slurm_ddp = (
+     self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)
+     and self.trainer.is_slurm_managing_tasks
+ )

# torchelastic or general non_slurm ddp
te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ)
- use_torchelastic_ddp = self.trainer.use_ddp and te_flags_passed
+ use_torchelastic_ddp = (
+     self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) and te_flags_passed
+ )

- use_ddp_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_spawn"
- use_ddp_cpu_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_cpu"
+ use_ddp_cpu_spawn = (
+     self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)
+     and self.trainer._device_type == DeviceType.CPU
+ )

use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self._is_using_torchelastic()
use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.trainer.is_slurm_managing_tasks
@@ -204,8 +211,9 @@ def select_accelerator(self):

cluster_env = self._select_environment()

+ # TODO: clean-up this branching as most just select class and uses the very same arguments
# choose the appropriate accelerator backend
- if self.trainer.use_ddp2:
+ if self.trainer._distrib_type == DistributedType.DDP2:
accelerator_backend = accelerators.DDP2Accelerator(
self.trainer,
cluster_env,
@@ -240,7 +248,7 @@ def select_accelerator(self):
self.trainer.plugin_connector.ddp_plugin
)

- elif use_ddp_spawn:
+ elif self.trainer._distrib_type == DistributedType.DDP_SPAWN:
accelerator_backend = accelerators.DDPSpawnAccelerator(
self.trainer,
nprocs=self.trainer.num_processes,
@@ -263,16 +271,16 @@ def select_accelerator(self):
ddp_plugin=self.trainer.plugin_connector.ddp_plugin
)

- elif self.trainer.use_dp:
+ elif self.trainer._distrib_type == DistributedType.DP:
accelerator_backend = accelerators.DataParallelAccelerator(self.trainer, cluster_env)

- elif self.trainer.use_horovod:
+ elif self.trainer._distrib_type == DistributedType.HOROVOD:
accelerator_backend = accelerators.HorovodAccelerator(self.trainer, cluster_env)

- elif self.trainer.use_single_gpu:
+ elif self.trainer._device_type == DeviceType.GPU and self.trainer.num_gpus == 1:
accelerator_backend = accelerators.GPUAccelerator(self.trainer, cluster_env)

- elif self.trainer.use_tpu:
+ elif self.trainer._device_type == DeviceType.TPU:
accelerator_backend = accelerators.TPUAccelerator(self.trainer, cluster_env)

elif self.trainer.distributed_backend is None:
@@ -347,13 +355,16 @@ def set_distributed_mode(self):
self._set_horovod_backend()

# throw error to force user ddp or ddp2 choice
- if self.trainer.num_nodes > 1 and self.trainer._distrib_type not in (DistributedType.DDP2, DistributedType.DDP):
+ _ddp = (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
+ if (self.trainer.num_nodes > 1 and self.trainer._distrib_type not in _ddp):
raise MisconfigurationException(
'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. '
'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`'
)

- rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self.trainer.on_gpu}')
+ rank_zero_info(
+     f'GPU available: {torch.cuda.is_available()}, used: {self.trainer._device_type == DeviceType.GPU}'
+ )
num_cores = self.trainer.tpu_cores if self.trainer.tpu_cores is not None else 0
rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores')

@@ -366,7 +377,7 @@ def _set_horovod_backend(self):

# Initialize Horovod to get rank / size info
hvd.init()
- if self.trainer.on_gpu:
+ if self.trainer._device_type == DeviceType.GPU:
# Horovod assigns one local GPU per process
self.trainer.root_gpu = hvd.local_rank()

8 changes: 4 additions & 4 deletions pytorch_lightning/accelerators/horovod_accelerator.py
@@ -19,7 +19,7 @@

from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
- from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, AMPType
+ from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, AMPType, DeviceType
from pytorch_lightning.utilities.distributed import rank_zero_only

if _HOROVOD_AVAILABLE:
@@ -46,7 +46,7 @@ def setup(self, model):
# call setup after the ddp process has connected
self.trainer.call_setup_hook(model)

- if torch.cuda.is_available() and self.trainer.on_gpu:
+ if torch.cuda.is_available() and self.trainer._device_type == DeviceType.GPU:
# Horovod: pin GPU to local rank
assert self.trainer.root_gpu == hvd.local_rank()
torch.cuda.set_device(self.trainer.root_gpu)
@@ -116,7 +116,7 @@ def train(self):
return results

def _step(self, model_step: Callable, args):
- if self.trainer.on_gpu:
+ if self.trainer._device_type == DeviceType.GPU:
args[0] = self.batch_to_device(args[0], hvd.local_rank())

if self.trainer.amp_backend == AMPType.NATIVE:
@@ -141,7 +141,7 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
optimizer.synchronize()

def on_train_epoch_end(self, outputs):
- hvd.join(hvd.local_rank() if self.trainer.on_gpu else -1)
+ hvd.join(hvd.local_rank() if self.trainer._device_type == DeviceType.GPU else -1)

def barrier(self, name: Optional[str] = None):
hvd.join()
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -27,7 +27,7 @@
from typing import Dict, List, Tuple

from pytorch_lightning.callbacks.base import Callback
- from pytorch_lightning.utilities import rank_zero_only
+ from pytorch_lightning.utilities import rank_zero_only, DeviceType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict

@@ -104,7 +104,7 @@ def on_train_start(self, trainer, *args, **kwargs):
'Cannot use GPUStatsMonitor callback with Trainer that has no logger.'
)

- if not trainer.on_gpu:
+ if trainer._device_type != DeviceType.GPU:
raise MisconfigurationException(
'You are using GPUStatsMonitor but are not running on GPU'
f' since gpus attribute in Trainer is set to {trainer.gpus}.'
13 changes: 2 additions & 11 deletions pytorch_lightning/core/lightning.py
@@ -85,17 +85,8 @@ def __init__(self, *args, **kwargs):
#: Pointer to the logger object
self.logger = None

- #: True if using dp
- self.use_dp = False
-
- #: True if using ddp
- self.use_ddp = False
-
- #: True if using ddp2
- self.use_ddp2 = False
-
- # True if on tpu
- self.use_tpu = False
+ self._distrib_type = None
+ self._device_type = None

#: True if using amp
self.use_amp = False
4 changes: 2 additions & 2 deletions pytorch_lightning/core/memory.py
@@ -23,7 +23,7 @@
import torch.nn as nn
from torch.utils.hooks import RemovableHandle

- from pytorch_lightning.utilities import AMPType
+ from pytorch_lightning.utilities import AMPType, DeviceType

PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"]
UNKNOWN_SIZE = "?"
@@ -229,7 +229,7 @@ def _forward_example_input(self) -> None:
input_ = model.example_input_array
input_ = model.transfer_batch_to_device(input_, model.device)

- if trainer is not None and trainer.amp_backend == AMPType.NATIVE and not trainer.use_tpu:
+ if trainer is not None and trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU:
model.forward = torch.cuda.amp.autocast()(model.forward)

mode = model.training
4 changes: 2 additions & 2 deletions pytorch_lightning/core/optimizer.py
@@ -17,7 +17,7 @@

from torch.optim.optimizer import Optimizer

- from pytorch_lightning.utilities import _TPU_AVAILABLE
+ from pytorch_lightning.utilities import _TPU_AVAILABLE, DeviceType
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _TPU_AVAILABLE:
@@ -125,7 +125,7 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n
optimizer = self._optimizer
model = trainer.get_model()

- if trainer.on_tpu:
+ if trainer._device_type == DeviceType.TPU:
with trainer.profiler.profile(profiler_name):
xm.optimizer_step(optimizer, optimizer_args={'closure': closure, **kwargs})

2 changes: 1 addition & 1 deletion pytorch_lightning/overrides/data_parallel.py
@@ -285,7 +285,7 @@ def _worker(i, module, input, kwargs, device=None):
if output is None:
warn_missing_output(fx_called)

- if output is not None and (module.use_dp or module.use_ddp2):
+ if output is not None and module._distrib_type in ('dp', 'ddp2'):
auto_squeeze_dim_zeros(output)
# ---------------

3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/ddp_plugin.py
@@ -22,6 +22,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.plugin import LightningPlugin
+ from pytorch_lightning.utilities import DeviceType


class DDPPlugin(LightningPlugin):
@@ -95,7 +96,7 @@ def init_ddp_connection(
os.environ["MASTER_ADDR"] = str(cluster_environment.master_address())
os.environ["MASTER_PORT"] = str(cluster_environment.master_port())
os.environ["WORLD_SIZE"] = str(cluster_environment.world_size())
torch_backend = "nccl" if trainer.on_gpu else "gloo"
torch_backend = "nccl" if trainer._device_type == DeviceType.GPU else "gloo"

if not torch_distrib.is_initialized():
log.info(
13 changes: 7 additions & 6 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -21,7 +21,8 @@

import pytorch_lightning
from pytorch_lightning.core.lightning import LightningModule
- from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, _OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn
+ from pytorch_lightning.utilities import (
+     _APEX_AVAILABLE, AMPType, _OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn, DeviceType)
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -50,26 +51,26 @@ def restore_weights(self, model: LightningModule) -> None:
3. don't restore
"""
# clear cache before restore
- if self.trainer.on_gpu:
+ if self.trainer._device_type == DeviceType.GPU:
torch.cuda.empty_cache()

# 1. Attempt to restore states from HPC checkpoint
dir_path_hpc = str(self.trainer.weights_save_path)
max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_")
if max_suffix is not None:
checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt'
- self.hpc_load(checkpoint_path, self.trainer.on_gpu)
+ self.hpc_load(checkpoint_path, self.trainer._device_type == DeviceType.GPU)
rank_zero_info(f'restored hpc model from: {checkpoint_path}')

# 2. Attempt to restore states from `resume_from_checkpoint` file
elif self.trainer.resume_from_checkpoint is not None and not self.trainer.testing:
- self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)
+ self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer._device_type == DeviceType.GPU)

# wait for all to catch up
self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights')

# clear cache after restore
- if self.trainer.on_gpu:
+ if self.trainer._device_type == DeviceType.GPU:
torch.cuda.empty_cache()

def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
@@ -291,7 +292,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:

# dump amp scaling
if (self.trainer.amp_backend == AMPType.NATIVE
- and not self.trainer.use_tpu
+ and self.trainer._device_type != DeviceType.TPU
and self.trainer.scaler is not None):
checkpoint['native_amp_scaling_state'] = self.trainer.scaler.state_dict()
elif self.trainer.amp_backend == AMPType.APEX:
pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
@@ -18,6 +18,7 @@
import torch

from pytorch_lightning.core.step_result import Result
+ from pytorch_lightning.utilities import DistributedType


class LoggerStages(str, Enum):
@@ -343,7 +344,7 @@ def cache_result(self) -> None:
hook_result.detach()
if self.trainer.move_metrics_to_cpu:
hook_result.cpu()
- elif self.trainer.use_dp:
+ elif self.trainer._distrib_type == DistributedType.DP:
hook_result.to(torch.device("cuda", self.trainer.root_gpu))

self._internals[fx_name].append(hook_result, dataloader_idx=dataloader_idx, extra_info=extra_info)
pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- from copy import deepcopy
import os
+ from copy import deepcopy
from pprint import pprint
from typing import Any, Iterable, Union, Dict

@@ -24,7 +24,7 @@
from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore, LoggerStages
from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
- from pytorch_lightning.utilities import flatten_dict
+ from pytorch_lightning.utilities import flatten_dict, DeviceType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden

@@ -81,7 +81,7 @@ def get_metrics(self, key: str) -> Dict:
metrics_holder = getattr(self, f"_{key}", None)
model_ref = self.trainer.get_model()
metrics_holder.convert(
- self.trainer.use_tpu,
+ self.trainer._device_type == DeviceType.TPU,
model_ref.device if model_ref is not None else model_ref
)
return metrics_holder.metrics
@@ -219,7 +219,7 @@ def log_metrics(self, metrics, grad_norm_dic, step=None, log_train_step_metrics=
and global_step for the rest.
"""
# add gpu memory
- if self.trainer.on_gpu and self.trainer.log_gpu_memory:
+ if self.trainer._device_type == DeviceType.GPU and self.trainer.log_gpu_memory:
mem_map = memory.get_memory_profile(self.trainer.log_gpu_memory)
metrics.update(mem_map)

7 changes: 2 additions & 5 deletions pytorch_lightning/trainer/connectors/model_connector.py
@@ -32,13 +32,10 @@ def copy_trainer_model_properties(self, model):
for m in [model, ref_model]:
m.trainer = self.trainer
m.logger = self.trainer.logger
- m.use_dp = self.trainer.use_dp
- m.use_ddp2 = self.trainer.use_ddp2
- m.use_ddp = self.trainer.use_ddp
+ m._device_type = str(self.trainer._device_type)
+ m._distrib_type = str(self.trainer._distrib_type)
m.use_amp = self.trainer.amp_backend is not None
m.testing = self.trainer.testing
- m.use_single_gpu = self.trainer.use_single_gpu
- m.use_tpu = self.trainer.use_tpu
m.tpu_local_core_rank = self.trainer.tpu_local_core_rank
m.tpu_global_core_rank = self.trainer.tpu_global_core_rank
m.precision = self.trainer.precision
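
Worth noting: the connector keeps enum members on the trainer but copies them onto the module through str(...), so module-side code ends up comparing plain strings. That is what the module._distrib_type in ('dp', 'ddp2') check in overrides/data_parallel.py above relies on. A minimal sketch of the two sides, with trainer and model as placeholders:

# Trainer side: enum members, compared with == / in against DistributedType and DeviceType.
trainer_uses_dp = trainer._distrib_type == DistributedType.DP

# Module side: copy_trainer_model_properties stores str(...) of the enums, so the
# LightningModule carries strings and string membership tests are used instead,
# as in the data_parallel check above.
model._device_type = str(trainer._device_type)
model._distrib_type = str(trainer._distrib_type)
model_uses_dp_like = model._distrib_type in ('dp', 'ddp2')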
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/connectors/slurm_connector.py
@@ -3,6 +3,7 @@
import signal
from subprocess import call
from pytorch_lightning import _logger as log
+ from pytorch_lightning.utilities import DeviceType, DistributedType
from pytorch_lightning.utilities.distributed import rank_zero_info
import torch.distributed as torch_distrib
import torch
@@ -22,7 +23,7 @@ def configure_slurm_ddp(self, num_gpu_nodes):
# extract SLURM flag vars
# whenever we have the correct number of tasks, we let slurm manage processes
# otherwise we launch the required number of processes
- if self.trainer.use_ddp or self.trainer.use_ddp2:
+ if self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2):
self.trainer.num_requested_gpus = self.trainer.num_gpus * num_gpu_nodes
self.trainer.num_slurm_tasks = 0
try:
@@ -145,7 +146,7 @@ def connect_ddp(self, global_rank: int, world_size: int) -> None:
os.environ["MASTER_ADDR"] = root_node
log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")

torch_backend = "nccl" if self.trainer.on_gpu else "gloo"
torch_backend = "nccl" if self.trainer._device_type == DeviceType.GPU else "gloo"

if not torch.distributed.is_initialized():
log.info(
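
Both DDPPlugin.init_ddp_connection above and connect_ddp here now derive the process-group backend from the device enum instead of on_gpu. A standalone sketch of the pattern follows; the init_process_group call itself sits below the visible cut of both hunks, so its exact arguments are an assumption, and MASTER_ADDR / MASTER_PORT / WORLD_SIZE are expected to be exported beforehand as done earlier in these functions.

import torch.distributed as torch_distrib

from pytorch_lightning.utilities import DeviceType

def init_connection(trainer, global_rank: int, world_size: int) -> None:
    # NCCL requires CUDA devices; anything else falls back to gloo.
    torch_backend = "nccl" if trainer._device_type == DeviceType.GPU else "gloo"
    if not torch_distrib.is_initialized():
        torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)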
4 changes: 0 additions & 4 deletions pytorch_lightning/trainer/data_loading.py
@@ -38,12 +38,8 @@ class TrainerDataLoadingMixin(ABC):
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
global_rank: int
- use_ddp: bool
- use_ddp2: bool
- use_horovod: bool
shown_warnings: ...
val_check_interval: float
- use_tpu: bool
tpu_local_core_rank: int
train_dataloader: DataLoader
num_training_batches: Union[int, float]