add missing typing to trainer properties #5974

Merged (12 commits) on Feb 15, 2021
118 changes: 59 additions & 59 deletions in pytorch_lightning/trainer/properties.py
@@ -15,14 +15,20 @@
import os
from abc import ABC
from argparse import ArgumentParser, Namespace
from typing import Any, cast, List, Optional, Type, TypeVar, Union
from typing import cast, List, Optional, Type, TypeVar, Union

import torch
from torch.optim import Optimizer

from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.accelerators.accelerator_connector import BackendConnector
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.plugins import PrecisionPlugin, TrainingTypePlugin
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, DeviceType, DistributedType, rank_zero_warn
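
For context, the typed-property style the rest of this diff adopts looks roughly like the following standalone sketch; `ConnectorStub` and `TypedPropertiesSketch` are invented names for illustration, not Lightning classes:

```python
from typing import List, Optional, Union


class ConnectorStub:
    """Illustrative stand-in for Lightning's BackendConnector; not the real class."""

    def __init__(self) -> None:
        self.num_nodes = 1
        self.parallel_device_ids: Optional[List[int]] = None


class TypedPropertiesSketch:
    # bare class-level annotations: visible to type checkers, no runtime effect
    precision: Union[str, int]
    accelerator_connector: ConnectorStub

    def __init__(self) -> None:
        self.accelerator_connector = ConnectorStub()

    @property
    def num_nodes(self) -> int:
        # a return annotation is the only change this diff makes to most properties
        return self.accelerator_connector.num_nodes

    @property
    def data_parallel_device_ids(self) -> Optional[List[int]]:
        return self.accelerator_connector.parallel_device_ids


sketch = TypedPropertiesSketch()
print(sketch.num_nodes)                 # 1
print(sketch.data_parallel_device_ids)  # None
```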
@@ -46,84 +52,78 @@

class TrainerProperties(ABC):

precision: int
logger_connector: LoggerConnector
_state: TrainerState
global_rank: int
fast_dev_run: Union[int, bool]
_device_type: DeviceType
_distrib_type: DistributedType
model: LightningModule
data_parallel_device_ids: Optional[List[int]]
_progress_bar_callback: ProgressBarBase
limit_val_batches: int
_default_root_dir: str
_lightning_optimizers = None
_progress_bar_callback: ProgressBarBase
_state: TrainerState
_weights_save_path: str
accelerator_backend: Accelerator
num_nodes: int
num_processes: int

accelerator_connector: BackendConnector
_lightning_optimizers = None
callbacks: List[Callback]
checkpoint_connector: CheckpointConnector
limit_val_batches: int
logger: LightningLoggerBase
logger_connector: LoggerConnector
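
Worth noting: the bare class-level annotations above only inform type checkers and do not create attributes at runtime, unlike the assigned `_lightning_optimizers = None`. A tiny Lightning-free sketch of the difference:

```python
class AttributeSketch:
    # annotation only: recorded in __annotations__, no class attribute is created
    limit_val_batches: int
    # annotation-free assignment: an ordinary class attribute with a default value
    _lightning_optimizers = None


print(AttributeSketch.__annotations__)                # {'limit_val_batches': <class 'int'>}
print(AttributeSketch._lightning_optimizers)          # None
print(hasattr(AttributeSketch, "limit_val_batches"))  # False
```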

@property
def accelerator(self):
def accelerator(self) -> Accelerator:
return self.accelerator_connector.accelerator

@property
def accelerator_backend(self):
def accelerator_backend(self) -> Accelerator:
# for backward compatibility
return self.accelerator
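
`accelerator_backend` is kept purely as a backward-compatible alias for `accelerator`; both resolve to the same object. A toy sketch of that alias pattern, with invented names:

```python
class EngineStub:
    """Illustrative stand-in for the accelerator object."""


class TrainerLike:
    def __init__(self) -> None:
        self._engine = EngineStub()

    @property
    def accelerator(self) -> EngineStub:
        return self._engine

    @property
    def accelerator_backend(self) -> EngineStub:
        # kept only for backward compatibility; both names return the same object
        return self.accelerator


trainer_like = TrainerLike()
print(trainer_like.accelerator is trainer_like.accelerator_backend)  # True
```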

@property
def distributed_backend(self):
def distributed_backend(self) -> Optional[str]:
# for backward compatibility
return self.accelerator_connector.distributed_backend

@property
def training_type_plugin(self):
def training_type_plugin(self) -> TrainingTypePlugin:
return self.accelerator.training_type_plugin

@property
def precision_plugin(self):
def precision_plugin(self) -> PrecisionPlugin:
return self.accelerator.precision_plugin

@property
def global_rank(self):
def global_rank(self) -> int:
return self.accelerator.training_type_plugin.global_rank

@property
def local_rank(self):
def local_rank(self) -> int:
# some training types define a local rank
return getattr(self.accelerator.training_type_plugin, "local_rank", 0)

@property
def node_rank(self):
def node_rank(self) -> int:
# some training types define a node rank
return getattr(self.accelerator.training_type_plugin, "node_rank", 0)

@property
def world_size(self):
def world_size(self) -> int:
# some training types define a world size
return getattr(self.accelerator.training_type_plugin, "world_size", 1)
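
The rank and world-size properties fall back to a default via `getattr`, which is why an `int` return annotation is always safe here. A reduced, Lightning-free sketch of that pattern:

```python
class SingleDevicePluginStub:
    """A plugin-like object that defines no distributed attributes."""


class DDPLikePluginStub:
    """A plugin-like object that does define them."""

    local_rank = 2
    node_rank = 0
    world_size = 8


def world_size_of(plugin: object) -> int:
    # mirrors the Trainer pattern: fall back to 1 when the plugin is not distributed
    return getattr(plugin, "world_size", 1)


print(world_size_of(SingleDevicePluginStub()))  # 1
print(world_size_of(DDPLikePluginStub()))       # 8
```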

@property
def _distrib_type(self):
def _distrib_type(self) -> DistributedType:
return self.accelerator_connector._distrib_type

@property
def _device_type(self):
def _device_type(self) -> DeviceType:
return self.accelerator_connector._device_type

@property
def num_nodes(self):
def num_nodes(self) -> int:
return self.accelerator_connector.num_nodes

@property
def num_processes(self):
def num_processes(self) -> int:
return self.accelerator_connector.num_processes

@property
def root_gpu(self):
def root_gpu(self) -> Optional[int]:
return self.accelerator_connector.root_gpu

@property
@@ -135,11 +135,11 @@ def num_gpus(self) -> int:
return self.accelerator_connector.num_gpus

@property
def data_parallel_device_ids(self):
def data_parallel_device_ids(self) -> Optional[List[int]]:
return self.accelerator_connector.parallel_device_ids

@property
def log_dir(self):
def log_dir(self) -> Optional[str]:
if self.logger is None:
dirpath = self.default_root_dir
else:
@@ -153,27 +153,27 @@ def use_amp(self) -> bool:
return self.precision == 16

@property
def callback_metrics(self):
def callback_metrics(self) -> dict:
return self.logger_connector.callback_metrics

@callback_metrics.setter
def callback_metrics(self, x):
def callback_metrics(self, x: dict) -> None:
self.logger_connector.callback_metrics = x

@property
def logged_metrics(self):
def logged_metrics(self) -> dict:
return self.logger_connector.logged_metrics

@logged_metrics.setter
def logged_metrics(self, x):
def logged_metrics(self, x: dict) -> None:
self.logger_connector.logged_metrics = x

@property
def progress_bar_metrics(self):
def progress_bar_metrics(self) -> dict:
return self.logger_connector.progress_bar_metrics

@progress_bar_metrics.setter
def progress_bar_metrics(self, x):
def progress_bar_metrics(self, x: dict) -> None:
self.logger_connector.progress_bar_metrics = x
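
The metrics properties above annotate both the getter and the setter. A self-contained sketch of the typed getter/setter pattern, using an invented `LoggerConnectorStub`:

```python
class LoggerConnectorStub:
    """Illustrative stand-in for Lightning's LoggerConnector."""

    def __init__(self) -> None:
        self.callback_metrics: dict = {}


class MetricsOwner:
    def __init__(self) -> None:
        self.logger_connector = LoggerConnectorStub()

    @property
    def callback_metrics(self) -> dict:
        # the getter advertises the value type
        return self.logger_connector.callback_metrics

    @callback_metrics.setter
    def callback_metrics(self, x: dict) -> None:
        # setters return nothing, hence the explicit "-> None"
        self.logger_connector.callback_metrics = x


owner = MetricsOwner()
owner.callback_metrics = {"val_loss": 0.25}
print(owner.callback_metrics)  # {'val_loss': 0.25}
```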

@property
@@ -200,7 +200,7 @@ def slurm_job_id(self) -> Optional[int]:
return job_id

@classmethod
def default_attributes(cls):
def default_attributes(cls) -> dict:
init_signature = inspect.signature(cls)

args = {}
@@ -246,7 +246,7 @@ def data_parallel(self) -> bool:
)

@property
def progress_bar_callback(self):
def progress_bar_callback(self) -> Optional[ProgressBarBase]:
return self._progress_bar_callback

@property
@@ -335,11 +335,11 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
"""
return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]

def save_checkpoint(self, filepath, weights_only: bool = False):
def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
self.checkpoint_connector.save_checkpoint(filepath, weights_only)

@property
def model(self) -> Any:
def model(self) -> torch.nn.Module:
"""
The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel.
To access the pure LightningModule, use
@@ -348,7 +348,7 @@ def model(self) -> Any:
return self.accelerator.model

@model.setter
def model(self, model: torch.nn.Module):
def model(self, model: torch.nn.Module) -> None:
"""
Setter for the model, pass-through to accelerator and plugin where the model reference is stored.
Used by the Tuner to reset the state of Trainer and Accelerator.
@@ -359,51 +359,51 @@ def model(self, model: torch.nn.Module):
"""
self.accelerator.model = model
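
The getter is typed as `torch.nn.Module` rather than `LightningModule` because, per the docstring, the stored model may be wrapped in `DataParallel` or `DistributedDataParallel`. A minimal plain-PyTorch illustration of why the broader type is the honest one:

```python
import torch
from torch import nn


class TinyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)


model: nn.Module = TinyModule()
# after wrapping, the object is a DataParallel, no longer a TinyModule,
# but it is still an nn.Module, which is all the annotation promises
wrapped: nn.Module = nn.DataParallel(model)
print(isinstance(wrapped, TinyModule))  # False
print(isinstance(wrapped, nn.Module))   # True
```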

def get_model(self):
def get_model(self) -> LightningModule:
# TODO: rename this to lightning_module (see training type plugin)
# backward compatible
return self.lightning_module

@property
def lightning_optimizers(self):
def lightning_optimizers(self) -> List[LightningOptimizer]:
if self._lightning_optimizers is None:
self.convert_to_lightning_optimizers()
return self._lightning_optimizers
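
`lightning_optimizers` is a lazily built cache that stays `None` until first access. A generic sketch of that lazy-property pattern (the cache contents here are arbitrary and stand in for Lightning's real conversion logic):

```python
from typing import List, Optional


class LazyCacheSketch:
    _cache: Optional[List[int]] = None

    def _build_cache(self) -> None:
        # stand-in for convert_to_lightning_optimizers(); just fills the cache
        self._cache = [1, 2, 3]

    @property
    def values(self) -> List[int]:
        if self._cache is None:
            self._build_cache()
        assert self._cache is not None  # helps a type checker after the lazy init
        return self._cache


print(LazyCacheSketch().values)  # [1, 2, 3]
```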

@property
def lightning_module(self):
def lightning_module(self) -> LightningModule:
return self.training_type_plugin.lightning_module

@property
def optimizers(self):
def optimizers(self) -> Optional[List[Optimizer]]:
return self.accelerator.optimizers

@optimizers.setter
def optimizers(self, new_optims):
def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None:
self.accelerator.optimizers = new_optims

@property
def lr_schedulers(self):
def lr_schedulers(self) -> Optional[list]:
return self.accelerator.lr_schedulers

@lr_schedulers.setter
def lr_schedulers(self, new_schedulers):
def lr_schedulers(self, new_schedulers: Optional[list]) -> None:
self.accelerator.lr_schedulers = new_schedulers

@property
def optimizer_frequencies(self):
def optimizer_frequencies(self) -> list:
return self.accelerator.optimizer_frequencies

@optimizer_frequencies.setter
def optimizer_frequencies(self, new_freqs):
def optimizer_frequencies(self, new_freqs: list) -> None:
self.accelerator.optimizer_frequencies = new_freqs
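
The optimizer-related setters now spell out the container types they accept, so a type checker can reject obviously wrong assignments. A hedged sketch with an invented `AcceleratorStub`:

```python
from typing import List, Optional

import torch
from torch.optim import Optimizer


class AcceleratorStub:
    """Illustrative stand-in for the accelerator that owns the optimizers."""

    def __init__(self) -> None:
        self.optimizers: Optional[List[Optimizer]] = None


class OptimizerOwner:
    def __init__(self) -> None:
        self.accelerator = AcceleratorStub()

    @property
    def optimizers(self) -> Optional[List[Optimizer]]:
        return self.accelerator.optimizers

    @optimizers.setter
    def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None:
        # with the parameter annotated, a checker flags e.g. `owner.optimizers = "sgd"`
        self.accelerator.optimizers = new_optims


params = [torch.nn.Parameter(torch.zeros(1))]
owner = OptimizerOwner()
owner.optimizers = [torch.optim.SGD(params, lr=0.1)]
print(owner.optimizers)
```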

@property
def amp_backend(self):
def amp_backend(self) -> Optional[str]:
return self.accelerator.amp_backend

@property
def precision(self):
def precision(self) -> Union[str, int]:
return self.accelerator.precision

@property
@@ -420,16 +420,16 @@ def __setstate__(self, state):
self.__dict__ = state

@property
def require_distributed_sampler(self):
if self.accelerator_backend is not None:
return self.accelerator_backend.require_distributed_sampler
def require_distributed_sampler(self) -> bool:
if self.accelerator is not None:
return self.accelerator.require_distributed_sampler
return self._distrib_type in (
DistributedType.HOROVOD, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2
) or self._device_type == DeviceType.TPU
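
The fallback branch of `require_distributed_sampler` is a membership test on enum values plus a device check. A reduced sketch of that idiom with a made-up enum (not Lightning's actual `DistributedType`):

```python
from enum import Enum


class DistModeSketch(Enum):
    # made-up enum for illustration only
    DP = "dp"
    DDP = "ddp"
    DDP_SPAWN = "ddp_spawn"
    HOROVOD = "horovod"


def needs_distributed_sampler(mode: DistModeSketch, on_tpu: bool) -> bool:
    # same shape as the Trainer fallback: certain strategies, or TPU, require it
    return mode in (DistModeSketch.DDP, DistModeSketch.DDP_SPAWN, DistModeSketch.HOROVOD) or on_tpu


print(needs_distributed_sampler(DistModeSketch.DP, on_tpu=False))   # False
print(needs_distributed_sampler(DistModeSketch.DDP, on_tpu=False))  # True
```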

@property
def distributed_sampler_kwargs(self):
if self.accelerator_backend is not None:
def distributed_sampler_kwargs(self) -> dict:
if self.accelerator is not None:
return self.training_type_plugin.distributed_sampler_kwargs

# TODO: make sure the cases below are handled by the training_type_plugin