diff --git a/pytorch_lightning/callbacks/timer.py b/pytorch_lightning/callbacks/timer.py index ef7e5866542258..c95e4d1c4a49ab 100644 --- a/pytorch_lightning/callbacks/timer.py +++ b/pytorch_lightning/callbacks/timer.py @@ -165,7 +165,7 @@ def on_load_checkpoint( def _check_time_remaining(self, trainer: "pl.Trainer") -> None: should_stop = self.time_elapsed() >= self._duration - should_stop = trainer.accelerator.broadcast(should_stop) + should_stop = trainer.accelerator.training_type_plugin.broadcast(should_stop) trainer.should_stop = trainer.should_stop or should_stop if should_stop and self._verbose: elapsed = timedelta(seconds=int(self.time_elapsed(RunningStage.TRAINING))) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c241863605e6e1..f5853fd887c458 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -594,7 +594,7 @@ def all_gather( the output will also be a collection with tensors of this shape. """ group = group if group is not None else torch.distributed.group.WORLD - all_gather = self.trainer.accelerator.all_gather + all_gather = self.trainer.accelerator.training_type_plugin.all_gather data = convert_to_tensors(data, device=self.device) return apply_to_collection(data, torch.Tensor, all_gather, group=group, sync_grads=sync_grads) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index df0f658bf712aa..ff527ed01692d4 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -31,9 +31,9 @@ import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.distributed import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward +from pytorch_lightning.overrides.torch_distributed import broadcast_object_list from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin @@ -48,13 +48,9 @@ rank_zero_deprecation, rank_zero_warn, ) -from pytorch_lightning.utilities.distributed import ( - distributed_available, - init_ddp_connection, - rank_zero_only, - ReduceOp, - sync_ddp_if_available, -) +from pytorch_lightning.utilities.distributed import distributed_available +from pytorch_lightning.utilities.distributed import group as _group +from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -116,7 +112,6 @@ def __init__( " Notice that it will be overriden by the trainer setting." 
) self._sync_batchnorm = sync_batchnorm or False - self.dist = LightningDistributed() self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0 self._ddp_kwargs = kwargs self._task_idx = None @@ -270,8 +265,6 @@ def setup_distributed(self): init_ddp_connection(self.cluster_environment, self.torch_distributed_backend) # set the ranks and devices - self.dist.rank = self.global_rank - self.dist.device = self.root_device def _check_can_spawn_children(self): if self.local_rank != 0: @@ -403,7 +396,15 @@ def barrier(self, *args, **kwargs) -> None: torch.distributed.barrier() def broadcast(self, obj: object, src: int = 0) -> object: - return self.dist.broadcast(obj) + if not distributed_available(): + raise RuntimeError( + "DDP is not initialized and torch.distributed is not available, cannot broadcast object" + ) + obj = [obj] + if self.global_rank != 0: + obj = [None] * len(obj) + broadcast_object_list(obj, src, group=_group.WORLD) + return obj[0] def pre_backward(self, closure_loss: torch.Tensor) -> None: """Run before precision plugin executes backward.""" diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 5f493001341d67..e46606402d8ca0 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -24,9 +24,9 @@ from torch.nn.parallel.distributed import DistributedDataParallel import pytorch_lightning as pl -from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward +from pytorch_lightning.overrides.torch_distributed import broadcast_object_list from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin @@ -40,13 +40,9 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load -from pytorch_lightning.utilities.distributed import ( - distributed_available, - init_ddp_connection, - rank_zero_only, - ReduceOp, - sync_ddp_if_available, -) +from pytorch_lightning.utilities.distributed import distributed_available +from pytorch_lightning.utilities.distributed import group as _group +from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -93,7 +89,6 @@ def __init__( ) self._sync_batchnorm = sync_batchnorm or False self._ddp_kwargs = kwargs - self.dist = LightningDistributed() self.num_processes = len(parallel_devices) if parallel_devices is not None else 0 self.mp_queue = None self._ddp_comm_state = ddp_comm_state @@ -193,10 +188,6 @@ def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQ # ...
need to double check that it is the correct place # self.trainer.call_setup_hook(self.model) - # set the ranks and devices - self.dist.rank = self.global_rank - self.dist.device = self.root_device - # move the model to the correct device self.model_to_device() @@ -324,7 +315,11 @@ def barrier(self, *args, **kwargs) -> None: def broadcast(self, obj: object, src: int = 0) -> object: if not distributed_available(): return obj - return self.dist.broadcast(obj) + obj = [obj] + if self.global_rank != 0: + obj = [None] * len(obj) + broadcast_object_list(obj, src, group=_group.WORLD) + return obj[0] def model_to_device(self): if self.root_device.type == "cuda": diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index cb3b007b712ff6..978152506d0e34 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -342,9 +342,6 @@ def setup_distributed(self): self._init_deepspeed_distributed() - # set the ranks and devices - self.dist.rank = self.global_rank - self.dist.device = self.root_device if not self._config_initialized: self._format_config() self._config_initialized = True diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 13d6f93f5fb972..7212d26136cfc1 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -13,7 +13,7 @@ # limitations under the License. import contextlib from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, TypeVar, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, TypeVar, Union import torch from torch import Tensor @@ -25,6 +25,7 @@ from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins import TorchCheckpointIO from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT TBroadcast = TypeVar("T") @@ -91,26 +92,53 @@ def is_global_zero(self) -> bool: """Whether the current process is the rank zero process not only on the local node, but for all nodes.""" @abstractmethod - def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]: + def reduce( + self, + tensor: Union[torch.Tensor, Any], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = "mean", + ) -> Union[torch.Tensor, Any]: """Reduces the given tensor (e.g. across GPUs/processes). Args: tensor: the tensor to sync and reduce + group: the process group to reduce + reduce_op: the reduction operation. Defaults to 'mean'. + Can also be a string 'sum' or ReduceOp. - *args: plugin-specific positional arguments - **kwargs: plugin-specific keyword arguments """ @abstractmethod def barrier(self, name: Optional[str] = None) -> None: - """Forces all possibly joined processes to wait for each other.""" + """Synchronizes all processes, blocking them until the whole group enters this function. + + Args: + name: an optional string passed into the barrier.
Only torch XLA respects this argument. + """ @abstractmethod - def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: - """Broadcasts an object to all processes.""" + def broadcast(self, obj: object, src: int = 0) -> object: + """Broadcasts an object to all processes. + + Args: + obj: the object to broadcast + src: source rank. + """ @abstractmethod - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """Perform a all_gather on all processes.""" + def all_gather( + self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False + ) -> Union[List[torch.Tensor], torch.Tensor]: + """Perform an all_gather on all processes. + + Args: + tensor: the tensor to all_gather + group: the process group to gather results from + sync_grads: flag that allows users to synchronize gradients for the all_gather operation + + Returns: a tensor (torch distributed) or a list of tensors (horovod) + """ def reduce_boolean_decision(self, decision: bool) -> bool: """Reduce the early stopping decision across all processes.""" diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 969404e68c498e..61160093b7d085 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -531,7 +531,7 @@ def request_dataloader( dataloader = self.call_hook(hook, pl_module=model) if isinstance(dataloader, tuple): dataloader = list(dataloader) - self.accelerator.barrier("get_dataloaders") + self.accelerator.training_type_plugin.barrier("get_dataloaders") return dataloader @staticmethod diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py new file mode 100644 index 00000000000000..ac4b344d92d058 --- /dev/null +++ b/pytorch_lightning/trainer/properties.py @@ -0,0 +1,661 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
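(Reviewer note on the `broadcast` rewrite above: having rank 0 supply the payload while every other rank supplies a same-length placeholder list is exactly how `torch.distributed.broadcast_object_list` is meant to be driven; the `pytorch_lightning.overrides.torch_distributed.broadcast_object_list` import is Lightning's backported wrapper around that call. A minimal, self-contained sketch of the pattern, assuming a 2-process gloo group on localhost; the `run` helper, port, and payload string are illustrative only:

```python
import os

import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    # every process joins the default (WORLD) process group
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # same shape as the plugins' `broadcast`: rank 0 carries the payload,
    # the other ranks hand in a placeholder list of equal length
    obj = ["some/checkpoint/path.ckpt"] if rank == 0 else [None]
    dist.broadcast_object_list(obj, src=0)
    assert obj[0] == "some/checkpoint/path.ckpt"  # now identical on every rank

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```

Wrapping the payload in a one-element list is what lets an arbitrary picklable object ride on the collective: the list is populated on the source rank and filled in place everywhere else, which is why the plugin returns `obj[0]`.)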
+import inspect +import os +from abc import ABC +from argparse import ArgumentParser, Namespace +from pathlib import Path +from typing import cast, List, Optional, Type, TypeVar, Union + +import torch +from torch.optim import Optimizer + +import pytorch_lightning as pl +from pytorch_lightning.accelerators import Accelerator +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter +from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.loggers.base import LoggerCollection +from pytorch_lightning.loggers.tensorboard import TensorBoardLogger +from pytorch_lightning.loops import PredictionLoop +from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop +from pytorch_lightning.loops.fit_loop import FitLoop +from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin +from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector +from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector +from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection +from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus +from pytorch_lightning.utilities import DeviceType, DistributedType, GradClipAlgorithmType, rank_zero_deprecation +from pytorch_lightning.utilities.argparse import ( + add_argparse_args, + from_argparse_args, + parse_argparser, + parse_env_variables, +) +from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.types import _PATH + + +class TrainerProperties(ABC): + + _default_root_dir: str + _fit_loop: FitLoop + _lightning_optimizers = None + _predict_loop: PredictionLoop + _progress_bar_callback: ProgressBarBase + _test_loop: EvaluationLoop + _validate_loop: EvaluationLoop + _weights_save_path: str + + accelerator_connector: AcceleratorConnector + accumulate_grad_batches: int + callbacks: List[Callback] + checkpoint_connector: CheckpointConnector + gradient_clip_algorithm: GradClipAlgorithmType + gradient_clip_val: float + limit_val_batches: int + logger: Optional[LightningLoggerBase] + logger_connector: LoggerConnector + reload_dataloaders_every_n_epochs: int + state: TrainerState + terminate_on_nan: bool + track_grad_norm: Union[int, float, str] + + # .validate() and .test() set this when they load a checkpoint + validated_ckpt_path: Optional[str] = None + tested_ckpt_path: Optional[str] = None + predicted_ckpt_path: Optional[str] = None + """ + Accelerator properties + """ + + @property + def accelerator(self) -> Accelerator: + return self.accelerator_connector.accelerator + + @property + def distributed_backend(self) -> Optional[str]: + # for backward compatibility + return self.accelerator_connector.distributed_backend + + @property + def training_type_plugin(self) -> TrainingTypePlugin: + return self.accelerator.training_type_plugin + + @property + def precision_plugin(self) -> PrecisionPlugin: + return self.accelerator.precision_plugin + + @property + def global_rank(self) -> int: + return self.accelerator.training_type_plugin.global_rank + + @property + def 
local_rank(self) -> int: + # some training types define a local rank + return getattr(self.accelerator.training_type_plugin, "local_rank", 0) + + @property + def node_rank(self) -> int: + # some training types define a node rank + return getattr(self.accelerator.training_type_plugin, "node_rank", 0) + + @property + def world_size(self) -> int: + # some training types define a world size + return getattr(self.accelerator.training_type_plugin, "world_size", 1) + + @property + def should_rank_save_checkpoint(self) -> bool: + return self.accelerator.training_type_plugin.should_rank_save_checkpoint + + @property + def _distrib_type(self) -> DistributedType: + return self.accelerator_connector._distrib_type + + @property + def _device_type(self) -> DeviceType: + return self.accelerator_connector._device_type + + @property + def num_nodes(self) -> int: + return self.accelerator_connector.num_nodes + + @property + def num_processes(self) -> int: + return self.accelerator_connector.num_processes + + @property + def root_gpu(self) -> Optional[int]: + return self.accelerator_connector.root_gpu + + @property + def tpu_cores(self) -> int: + return self.accelerator_connector.tpu_cores + + @property + def ipus(self) -> int: + return self.accelerator_connector.num_ipus + + @property + def num_gpus(self) -> int: + return self.accelerator_connector.num_gpus + + @property + def devices(self) -> Optional[Union[List[int], str, int]]: + return self.accelerator_connector.devices + + @property + def data_parallel_device_ids(self) -> Optional[List[int]]: + return self.accelerator_connector.parallel_device_ids + + @property + def lightning_module(self) -> "pl.LightningModule": + return self.accelerator.lightning_module + + @property + def optimizers(self) -> List[Optimizer]: + return self.accelerator.optimizers + + @optimizers.setter + def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None: + # Necessary to rewrap optimizers to lightning + # They will be re-created when accessing + # the `lightning_optimizers` trainer property + self._lightning_optimizers = None + + self.accelerator.optimizers = new_optims + + @property + def lr_schedulers(self) -> Optional[list]: + return self.accelerator.lr_schedulers + + @lr_schedulers.setter + def lr_schedulers(self, new_schedulers: Optional[list]) -> None: + self.accelerator.lr_schedulers = new_schedulers + + @property + def optimizer_frequencies(self) -> list: + return self.accelerator.optimizer_frequencies + + @optimizer_frequencies.setter + def optimizer_frequencies(self, new_freqs: list) -> None: + self.accelerator.optimizer_frequencies = new_freqs + + @property + def amp_backend(self) -> Optional[str]: + return self.accelerator.amp_backend + + @property + def precision(self) -> Union[str, int]: + return self.accelerator.precision + + @property + def scaler(self): + return self.accelerator.scaler + + @property + def gpus(self) -> Optional[Union[List[int], str, int]]: + return self.accelerator_connector.gpus + + @property + def model(self) -> torch.nn.Module: + """The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel. + + To access the pure LightningModule, use + :meth:`~pytorch_lightning.trainer.trainer.Trainer.lightning_module` instead. + """ + return self.accelerator.model + + @model.setter + def model(self, model: torch.nn.Module) -> None: + """Setter for the model, pass-through to accelerator and plugin where the model reference is stored. Used + by the Tuner to reset the state of Trainer and Accelerator.
+ + Args: + model: The LightningModule, possibly wrapped into DataParallel or DistributedDataParallel, depending + on the backend. + """ + self.accelerator.model = model + + """ + General properties + """ + + @property + def log_dir(self) -> Optional[str]: + if self.logger is None: + dirpath = self.default_root_dir + elif isinstance(self.logger, TensorBoardLogger): + dirpath = self.logger.log_dir + elif isinstance(self.logger, LoggerCollection): + dirpath = self.default_root_dir + else: + dirpath = self.logger.save_dir + + dirpath = self.accelerator.training_type_plugin.broadcast(dirpath) + return dirpath + + @property + def use_amp(self) -> bool: + return self.precision == 16 + + @property + def is_global_zero(self) -> bool: + return self.global_rank == 0 + + @property + def slurm_job_id(self) -> Optional[int]: + job_id = os.environ.get("SLURM_JOB_ID") + if job_id: + try: + job_id = int(job_id) + except ValueError: + job_id = None + + # in interactive mode, don't make logs use the same job id + in_slurm_interactive_mode = os.environ.get("SLURM_JOB_NAME") == "bash" + if in_slurm_interactive_mode: + job_id = None + return job_id + + @property + def lightning_optimizers(self) -> List[LightningOptimizer]: + if self._lightning_optimizers is None: + self.convert_to_lightning_optimizers() + return self._lightning_optimizers + + @property + def distributed_sampler_kwargs(self) -> Optional[dict]: + if isinstance(self.training_type_plugin, ParallelPlugin): + return self.training_type_plugin.distributed_sampler_kwargs + + @property + def data_parallel(self) -> bool: + return self._distrib_type in ( + DistributedType.DP, + DistributedType.DDP, + DistributedType.DDP_SPAWN, + DistributedType.DDP2, + ) + + @property + def progress_bar_callback(self) -> Optional[ProgressBarBase]: + return self._progress_bar_callback + + @property + def progress_bar_dict(self) -> dict: + """Read-only for progress bar metrics.""" + rank_zero_deprecation( + "`trainer.progress_bar_dict` is deprecated in v1.5 and will be removed in v1.7." + " Use `ProgressBarBase.get_metrics` instead." + ) + ref_model = self.lightning_module + ref_model = cast(pl.LightningModule, ref_model) + if self.progress_bar_callback: + return self.progress_bar_callback.get_metrics(self, ref_model) + return self.progress_bar_metrics + + @property + def _should_reload_dl_epoch(self) -> bool: + """Check if dataloader should be reloaded in the current epoch.""" + n_epochs = self.reload_dataloaders_every_n_epochs + return n_epochs and (not self.current_epoch % n_epochs) + + @property + def disable_validation(self) -> bool: + """Check if validation is disabled during training.""" + rank_zero_deprecation( + "`trainer.disable_validation` is deprecated in v1.4 and will be removed in v1.6." + " Use `not trainer.enable_validation` instead." + ) + return not self.enable_validation + + @property + def enable_validation(self) -> bool: + """Check if we should run validation during training.""" + model_ref = self.lightning_module + val_loop_enabled = is_overridden("validation_step", model_ref) and self.limit_val_batches > 0 + return val_loop_enabled + + @property + def default_root_dir(self) -> str: + """The default location to save artifacts of loggers, checkpoints etc. + + It is used as a fallback if logger or checkpoint callback do not define specific save paths. 
+ """ + if get_filesystem(self._default_root_dir).protocol == "file": + return os.path.normpath(self._default_root_dir) + return self._default_root_dir + + @property + def weights_save_path(self) -> str: + """ + The default root location to save weights (checkpoints), e.g., when the + :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path. + """ + if get_filesystem(self._weights_save_path).protocol == "file": + return os.path.normpath(self._weights_save_path) + return self._weights_save_path + + @property + def early_stopping_callback(self) -> Optional[EarlyStopping]: + """The first :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` callback in the + Trainer.callbacks list, or ``None`` if it doesn't exist.""" + callbacks = self.early_stopping_callbacks + return callbacks[0] if len(callbacks) > 0 else None + + @property + def early_stopping_callbacks(self) -> List[EarlyStopping]: + """A list of all instances of :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` found in + the Trainer.callbacks list.""" + return [c for c in self.callbacks if isinstance(c, EarlyStopping)] + + @property + def prediction_writer_callbacks(self) -> List[BasePredictionWriter]: + """A list of all instances of :class:`~pytorch_lightning.callbacks.prediction_writer.BasePredictionWriter` + found in the Trainer.callbacks list.""" + return [cb for cb in self.callbacks if isinstance(cb, BasePredictionWriter)] + + @property + def checkpoint_callback(self) -> Optional[ModelCheckpoint]: + """The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback in the + Trainer.callbacks list, or ``None`` if it doesn't exist.""" + callbacks = self.checkpoint_callbacks + return callbacks[0] if len(callbacks) > 0 else None + + @property + def checkpoint_callbacks(self) -> List[ModelCheckpoint]: + """A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` found + in the Trainer.callbacks list.""" + return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)] + + @property + def resume_from_checkpoint(self) -> Optional[Union[str, Path]]: + return self.checkpoint_connector.resume_checkpoint_path + + def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None: + self.checkpoint_connector.save_checkpoint(filepath, weights_only) + + """ + Parsing properties + """ + + @classmethod + def default_attributes(cls) -> dict: + init_signature = inspect.signature(cls) + return {k: v.default for k, v in init_signature.parameters.items()} + + @classmethod + def get_deprecated_arg_names(cls) -> List: + """Returns a list with deprecated Trainer arguments.""" + depr_arg_names = [] + for name, val in cls.__dict__.items(): + if name.startswith("DEPRECATED") and isinstance(val, (tuple, list)): + depr_arg_names.extend(val) + return depr_arg_names + + @classmethod + def from_argparse_args(cls: Type["_T"], args: Union[Namespace, ArgumentParser], **kwargs) -> "_T": + return from_argparse_args(cls, args, **kwargs) + + @classmethod + def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: + return parse_argparser(cls, arg_parser) + + @classmethod + def match_env_arguments(cls) -> Namespace: + return parse_env_variables(cls) + + @classmethod + def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser: + return add_argparse_args(cls, parent_parser, **kwargs) + + """ + State properties + """ + + @property + def interrupted(self) -> bool: + 
return self.state.status == TrainerStatus.INTERRUPTED + + @property + def training(self) -> bool: + return self.state.stage == RunningStage.TRAINING + + @training.setter + def training(self, val: bool) -> None: + if val: + self.state.stage = RunningStage.TRAINING + elif self.training: + self.state.stage = None + + @property + def testing(self) -> bool: + return self.state.stage == RunningStage.TESTING + + @testing.setter + def testing(self, val: bool) -> None: + if val: + self.state.stage = RunningStage.TESTING + elif self.testing: + self.state.stage = None + + @property + def predicting(self) -> bool: + return self.state.stage == RunningStage.PREDICTING + + @predicting.setter + def predicting(self, val: bool) -> None: + if val: + self.state.stage = RunningStage.PREDICTING + elif self.predicting: + self.state.stage = None + + @property + def tuning(self) -> bool: + return self.state.stage == RunningStage.TUNING + + @tuning.setter + def tuning(self, val: bool) -> None: + if val: + self.state.stage = RunningStage.TUNING + elif self.tuning: + self.state.stage = None + + @property + def validating(self) -> bool: + return self.state.stage == RunningStage.VALIDATING + + @validating.setter + def validating(self, val: bool) -> None: + if val: + self.state.stage = RunningStage.VALIDATING + elif self.validating: + self.state.stage = None + + @property + def evaluating(self) -> bool: + return self.state.stage and self.state.stage.evaluating + + @property + def sanity_checking(self) -> bool: + return self.state.stage == RunningStage.SANITY_CHECKING + + @sanity_checking.setter + def sanity_checking(self, val: bool) -> None: + if val: + self.state.stage = RunningStage.SANITY_CHECKING + elif self.sanity_checking: + self.state.stage = None + + """ + Loop properties + """ + + @property + def global_step(self) -> int: + return self.fit_loop.global_step + + @property + def current_epoch(self) -> int: + return self.fit_loop.current_epoch + + @property + def max_epochs(self) -> Optional[int]: + return self.fit_loop.max_epochs + + @property + def min_epochs(self) -> Optional[int]: + return self.fit_loop.min_epochs + + @property + def max_steps(self) -> Optional[int]: + return self.fit_loop.max_steps + + @property + def min_steps(self) -> Optional[int]: + return self.fit_loop.min_steps + + @property + def is_last_batch(self) -> bool: + return self.fit_loop.epoch_loop.is_last_batch + + @property + def fit_loop(self) -> FitLoop: + return self._fit_loop + + @fit_loop.setter + def fit_loop(self, loop: FitLoop): + """Attach a custom fit loop to this Trainer. + + It will run with + :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`. + """ + loop.trainer = self + self._fit_loop = loop + + @property + def validate_loop(self) -> EvaluationLoop: + return self._validate_loop + + @validate_loop.setter + def validate_loop(self, loop: EvaluationLoop): + """Attach a custom validation loop to this Trainer. + + It will run with + :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`. Note that this loop is different from the one + running during training inside the :meth:`pytorch_lightning.trainer.trainer.Trainer.fit` call. + """ + loop.trainer = self + self._validate_loop = loop + + @property + def test_loop(self) -> EvaluationLoop: + return self._test_loop + + @test_loop.setter + def test_loop(self, loop: EvaluationLoop): + """Attach a custom test loop to this Trainer. + + It will run with + :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`.
+ """ + loop.trainer = self + self._test_loop = loop + + @property + def predict_loop(self) -> PredictionLoop: + return self._predict_loop + + @predict_loop.setter + def predict_loop(self, loop: PredictionLoop): + """Attach a custom prediction loop to this Trainer. + + It will run with + :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. + """ + loop.trainer = self + self._predict_loop = loop + + @property + def _evaluation_loop(self) -> EvaluationLoop: + if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING): + return self.fit_loop.epoch_loop.val_loop + if self.state.fn == TrainerFn.VALIDATING: + return self.validate_loop + if self.state.fn == TrainerFn.TESTING: + return self.test_loop + raise RuntimeError("The `Trainer._evaluation_loop` property isn't defined. Accessed outside of scope") + + @property + def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop]]: + if self.training: + return self.fit_loop + if self.sanity_checking or self.evaluating: + return self._evaluation_loop + if self.predicting: + return self.predict_loop + + @property + def _ckpt_path(self) -> Optional[str]: + if self.state.fn == TrainerFn.VALIDATING: + return self.validated_ckpt_path + if self.state.fn == TrainerFn.TESTING: + return self.tested_ckpt_path + if self.state.fn == TrainerFn.PREDICTING: + return self.predicted_ckpt_path + + """ + Logging properties + """ + + @property + def callback_metrics(self) -> dict: + return self.logger_connector.callback_metrics + + @property + def logged_metrics(self) -> dict: + return self.logger_connector.logged_metrics + + @property + def progress_bar_metrics(self) -> dict: + return self.logger_connector.progress_bar_metrics + + @property + def _results(self) -> Optional[ResultCollection]: + active_loop = self._active_loop + if active_loop is not None: + return active_loop._results + + """ + Other + """ + + # TODO: refactor this so that it can be done in LightningOptimizer + def __getstate__(self): + # remove lightning_optimizers + self._lightning_optimizers = None + return self.__dict__ + + def __setstate__(self, state): + self.__dict__ = state + + +# Used to represent the concrete type TrainerProperties class methods are called on. +_T = TypeVar("_T", bound=TrainerProperties) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 581ff11554cb35..424b2db5202000 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -958,7 +958,7 @@ def _load_checkpoint_weights(self): # only one process running at this point for TPUs, as spawn isn't triggered yet # todo: move this logic internally within the barrier. 
if not self._device_type == DeviceType.TPU: - self.accelerator.barrier() + self.accelerator.training_type_plugin.barrier() rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}") self.checkpoint_connector.restore_model_weights(self._ckpt_path) @@ -1141,7 +1141,7 @@ def run_stage(self): def _pre_training_routine(self): # wait for all to join if on distributed - self.accelerator.barrier("setup_training") + self.accelerator.training_type_plugin.barrier("setup_training") # register signals self.signal_connector.register_signal_handlers() @@ -1282,13 +1282,13 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_ def _call_setup_hook(self) -> None: fn = self.state.fn._setup_fn - self.accelerator.barrier("pre_setup") + self.accelerator.training_type_plugin.barrier("pre_setup") if self.datamodule is not None: self.datamodule.setup(stage=fn) self.call_hook("setup", stage=fn) - self.accelerator.barrier("post_setup") + self.accelerator.training_type_plugin.barrier("post_setup") def _call_configure_sharded_model(self) -> None: # Call configure sharded model hook if accelerator requests. In some cases @@ -1606,7 +1606,7 @@ def log_dir(self) -> Optional[str]: else: dirpath = self.logger.save_dir - dirpath = self.accelerator.broadcast(dirpath) + dirpath = self.accelerator.training_type_plugin.broadcast(dirpath) return dirpath @property diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py index bd13275e9e5d19..6faa92f95346be 100644 --- a/tests/plugins/test_ddp_plugin.py +++ b/tests/plugins/test_ddp_plugin.py @@ -57,11 +57,11 @@ def test_ddp_with_2_gpus(): class BarrierModel(BoringModel): def setup(self, stage=None): assert not isinstance(self.trainer.accelerator.model, DistributedDataParallel) - self.trainer.accelerator.barrier("barrier before model is wrapped") + self.trainer.accelerator.training_type_plugin.barrier("barrier before model is wrapped") def on_train_start(self): assert isinstance(self.trainer.accelerator.model, DistributedDataParallel) - self.trainer.accelerator.barrier("barrier after model is wrapped") + self.trainer.accelerator.training_type_plugin.barrier("barrier after model is wrapped") @RunIf(min_gpus=4, special=True) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index c7ccaab3e72f46..d27a35c8203a87 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -830,9 +830,9 @@ def test_deepspeed_plugin_env_variables(mock_deepspeed_distributed, tmpdir, plat def _assert_save_model_is_equal(model, tmpdir, trainer): checkpoint_path = os.path.join(tmpdir, "model.pt") - checkpoint_path = trainer.accelerator.broadcast(checkpoint_path) + checkpoint_path = trainer.accelerator.training_type_plugin.broadcast(checkpoint_path) trainer.save_checkpoint(checkpoint_path) - trainer.accelerator.barrier() + trainer.accelerator.training_type_plugin.barrier() # carry out the check only on rank 0 if trainer.is_global_zero: diff --git a/tests/utilities/test_deepspeed_collate_checkpoint.py b/tests/utilities/test_deepspeed_collate_checkpoint.py index a04e56b7aabad3..b6074289d1b82f 100644 --- a/tests/utilities/test_deepspeed_collate_checkpoint.py +++ b/tests/utilities/test_deepspeed_collate_checkpoint.py @@ -31,9 +31,9 @@ def test_deepspeed_collate_checkpoint(tmpdir): ) trainer.fit(model) checkpoint_path = os.path.join(tmpdir, "model.pt") - checkpoint_path = trainer.accelerator.broadcast(checkpoint_path) + checkpoint_path = 
trainer.accelerator.training_type_plugin.broadcast(checkpoint_path) trainer.save_checkpoint(checkpoint_path) - trainer.accelerator.barrier() + trainer.accelerator.training_type_plugin.barrier() if trainer.is_global_zero: # ensure function call works output_path = os.path.join(tmpdir, "single_model.pt")
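(Reviewer note: after this migration, code that previously reached collectives through `trainer.accelerator` should go through `trainer.accelerator.training_type_plugin`, as the updated call sites and tests above do. A minimal sketch of the broadcast-then-barrier pattern the DeepSpeed tests rely on; the `save_and_sync_checkpoint` helper name is illustrative, everything else mirrors the test code in this diff:

```python
import os

from pytorch_lightning import Trainer


def save_and_sync_checkpoint(trainer: Trainer, tmpdir: str) -> None:
    # rank 0 picks the path; broadcasting it keeps every rank in agreement
    checkpoint_path = os.path.join(tmpdir, "model.pt")
    checkpoint_path = trainer.accelerator.training_type_plugin.broadcast(checkpoint_path)
    trainer.save_checkpoint(checkpoint_path)
    # block until all ranks have finished writing before rank 0 reads it back
    trainer.accelerator.training_type_plugin.barrier()
    if trainer.is_global_zero:
        print(f"checkpoint written to {checkpoint_path}")
```

The barrier after `save_checkpoint` matters for sharded plugins such as DeepSpeed, where every rank writes its own shard and rank 0 must not collate or inspect the checkpoint until all shards exist.)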