TrainerState refactor [5/5] (#7173)
* `TrainerState` refactor

* flake8

* Update finished check

* Test cleanup

* Fix tests

* Fixes

* Reorder

* flake8

* Update CHANGELOG

* Better docs

* Better docs

* Remove default

* Update tests

* Bad merge
carmocca authored May 4, 2021
1 parent a6aa1a0 commit 8c0ea92
Showing 50 changed files with 295 additions and 304 deletions.
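
Taken together, the hunks below replace the old single-enum `TrainerState` with a composite `trainer.state` object: call sites now compare `trainer.state.fn` against the `TrainerFn` entry-point enum, read the per-loop stage from `trainer.state.stage`, and (per the CHANGELOG entry) track the run status via `TrainerStatus`. The following is a minimal sketch of the shape those usages imply — the enum members come from this diff and the CHANGELOG, but the dataclass layout, defaults, and string values are assumptions, not the code from this commit:

    from dataclasses import dataclass
    from enum import Enum
    from typing import Optional


    class TrainerStatus(Enum):
        INITIALIZING = "initializing"
        RUNNING = "running"
        FINISHED = "finished"
        INTERRUPTED = "interrupted"


    class TrainerFn(Enum):
        # which Trainer entry point was called
        FITTING = "fit"
        VALIDATING = "validate"
        TESTING = "test"
        PREDICTING = "predict"
        TUNING = "tune"


    @dataclass
    class TrainerState:
        status: TrainerStatus = TrainerStatus.INITIALIZING
        fn: Optional[TrainerFn] = None              # e.g. trainer.state.fn == TrainerFn.FITTING
        stage: Optional["RunningStage"] = None      # e.g. trainer.state.stage, incl. SANITY_CHECKING

With this shape, the mechanical change repeated across the files below is `trainer.state == TrainerState.X` becoming `trainer.state.fn == TrainerFn.X`.
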
7 changes: 5 additions & 2 deletions CHANGELOG.md
@@ -46,7 +46,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `RunningStage.SANITY_CHECKING` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


-- Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+- Added `TrainerFn.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945), [#7173](https://github.com/PyTorchLightning/pytorch-lightning/pull/7173))
+
+
+- Added `TrainerStatus.{INITIALIZING,RUNNING,FINISHED,INTERRUPTED}` ([#7173](https://github.com/PyTorchLightning/pytorch-lightning/pull/7173))


- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ([#4948](https://github.com/PyTorchLightning/pytorch-lightning/pull/4948))
@@ -157,7 +160,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))


-- Refactor `RunningStage` and `TrainerState` usage ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+- Refactor `RunningStage` and `TrainerState` usage ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945), [#7173](https://github.com/PyTorchLightning/pytorch-lightning/pull/7173))


- Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
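
The CHANGELOG entries above split the old enum along two axes: `TrainerFn` records which `Trainer` entry point was called, while `TrainerStatus` records whether that call is still running, finished, or was interrupted. As a hedged illustration of how downstream code might consume the new status after a run (the attribute path `trainer.state.status` is an assumption based on the naming; it does not appear verbatim in this diff):

    from pytorch_lightning.trainer.states import TrainerStatus


    def run_completed(trainer) -> bool:
        # True once fit/validate/test/predict returned normally; False while the
        # run is INITIALIZING/RUNNING or after it was INTERRUPTED (e.g. Ctrl-C).
        return trainer.state.status == TrainerStatus.FINISHED
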
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/accelerator.py
@@ -24,7 +24,7 @@
import pytorch_lightning as pl
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
-from pytorch_lightning.trainer.states import TrainerState
+from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
@@ -374,7 +374,7 @@ def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
Args:
trainer: the Trainer, these optimizers should be connected to
"""
-if trainer.state not in (TrainerState.FITTING, TrainerState.TUNING):
+if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
return
optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers(
trainer=trainer, model=self.lightning_module
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -158,8 +158,8 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:
self.patience = callback_state['patience']

def _should_skip_check(self, trainer) -> bool:
-from pytorch_lightning.trainer.states import TrainerState
-return trainer.state != TrainerState.FITTING or trainer.sanity_checking
+from pytorch_lightning.trainer.states import TrainerFn
+return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking

def on_train_epoch_end(self, trainer, pl_module, outputs) -> None:
if not self._check_on_train_epoch_end or self._should_skip_check(trainer):
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -286,10 +286,10 @@ def save_checkpoint(self, trainer, unused: Optional = None):
self._save_last_checkpoint(trainer, monitor_candidates)

def _should_skip_saving_checkpoint(self, trainer) -> bool:
-from pytorch_lightning.trainer.states import TrainerState
+from pytorch_lightning.trainer.states import TrainerFn
return (
trainer.fast_dev_run # disable checkpointing with fast_dev_run
-or trainer.state != TrainerState.FITTING # don't save anything during non-fit
+or trainer.state.fn != TrainerFn.FITTING # don't save anything during non-fit
or trainer.sanity_checking # don't save anything during sanity check
or self._last_global_step_saved == trainer.global_step # already saved at the last step
)
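
Both callback hunks above keep their original guard, just re-pointed at `trainer.state.fn`. A custom callback that wants the same behaviour — act only during a real fit, never during sanity checking — could mirror it like this (hypothetical callback, not part of this commit):

    from pytorch_lightning import Callback
    from pytorch_lightning.trainer.states import TrainerFn


    class FitOnlyNotifier(Callback):
        def on_validation_end(self, trainer, pl_module):
            # same guard as EarlyStopping._should_skip_check and
            # ModelCheckpoint._should_skip_saving_checkpoint above
            if trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking:
                return
            print(f"validation finished at global step {trainer.global_step}")
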
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/quantization.py
@@ -30,7 +30,7 @@

def wrap_qat_forward_context(
quant_cb,
-model: pl.core.LightningModule,
+model: 'pl.LightningModule',
func: Callable,
trigger_condition: Optional[Union[Callable, int]] = None
) -> Callable:
@@ -57,7 +57,7 @@ def wrapper(data) -> Any:
return wrapper


-def wrap_quantize_forward_context(model: pl.core.LightningModule, func: Callable) -> Callable:
+def wrap_quantize_forward_context(model: 'pl.LightningModule', func: Callable) -> Callable:
"""
Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out compatibility
"""
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -27,7 +27,7 @@
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
-from pytorch_lightning.trainer.states import TrainerState
+from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
@@ -245,7 +245,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
# save the last weights
last_path = None
if (
-self.lightning_module.trainer.state == TrainerState.FITTING and best_model_path is not None
+self.lightning_module.trainer.state.fn == TrainerFn.FITTING and best_model_path is not None
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
@@ -263,7 +263,7 @@ def __recover_child_process_weights(self, best_path, last_path):
# todo, pass also best score

# load last weights
-if last_path is not None and self.lightning_module.trainer.state == TrainerState.FITTING:
+if last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
ckpt = pl_load(last_path, map_location=lambda storage, loc: storage)
self.lightning_module.load_state_dict(ckpt)

6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -510,16 +510,16 @@ def restore_model_state_from_ckpt_path(
) -> Tuple[Dict, bool]:
if not self.save_full_weights and self.world_size > 1:
# Rely on deepspeed to load the checkpoint and necessary information
-from pytorch_lightning.trainer.states import TrainerState
-stage_is_fit = self.lightning_module.trainer.state == TrainerState.FITTING
+from pytorch_lightning.trainer.states import TrainerFn
+is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
save_dir = self._filepath_to_dir(ckpt_path)

if self.zero_stage_3:
# TODO: Currently required as this call is missing within the deepspeed engine.
self.deepspeed_engine.optimizer._partition_all_parameters()

_, client_state = self.deepspeed_engine.load_checkpoint(
-save_dir, load_optimizer_states=stage_is_fit, load_lr_scheduler_states=stage_is_fit
+save_dir, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting
)

# restore datamodule states
8 changes: 4 additions & 4 deletions pytorch_lightning/plugins/training_type/rpc_sequential.py
@@ -24,7 +24,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.distributed import LightningDistributedModule
from pytorch_lightning.plugins.training_type.rpc import DEFAULT_RPC_TIMEOUT_SEC, RPCPlugin
-from pytorch_lightning.trainer.states import TrainerState
+from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -208,7 +208,7 @@ def _skip_init_connections(self):
Returns: Whether to skip initialization
"""
-return torch_distrib.is_initialized() and self.lightning_module.trainer.state != TrainerState.FITTING
+return torch_distrib.is_initialized() and self.lightning_module.trainer.state.fn != TrainerFn.FITTING

def init_model_parallel_groups(self):
num_model_parallel = 1 # TODO currently no support for vertical model parallel
@@ -231,7 +231,7 @@ def _infer_check_num_gpus(self):
return self.world_size

def handle_transferred_pipe_module(self) -> None:
-if self.lightning_module.trainer.state == TrainerState.FITTING:
+if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
torch_distrib.barrier() # Ensure we await main process initialization
# Add trainer/configure_optimizers to the pipe model for access in all worker processes
rpc_pipe.PipeModel.trainer = self.lightning_module.trainer
@@ -243,7 +243,7 @@ def init_pipe_module(self) -> None:
# Create pipe_module
model = self.lightning_module
self._find_and_init_pipe_module(model)
-if self.lightning_module.trainer.state == TrainerState.FITTING:
+if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
torch_distrib.barrier() # Ensure we join main process initialization
model.sequential_module.foreach_worker(register_optimizers, include_self=True)

4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/sharded.py
@@ -19,7 +19,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import is_lightning_optimizer
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
-from pytorch_lightning.trainer.states import TrainerState
+from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -66,7 +66,7 @@ def _reinit_optimizers_with_oss(self):
trainer.convert_to_lightning_optimizers()

def _wrap_optimizers(self):
-if self.model.trainer.state != TrainerState.FITTING:
+if self.model.trainer.state.fn != TrainerFn.FITTING:
return
self._reinit_optimizers_with_oss()

4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -19,7 +19,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
-from pytorch_lightning.trainer.states import TrainerState
+from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -53,7 +53,7 @@ def _reinit_optimizers_with_oss(self):
trainer.optimizers = optimizers

def _wrap_optimizers(self):
-if self.model.trainer.state != TrainerState.FITTING:
+if self.model.trainer.state.fn != TrainerFn.FITTING:
return
self._reinit_optimizers_with_oss()

4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -26,7 +26,7 @@
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
-from pytorch_lightning.trainer.states import TrainerState
+from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.data import has_len
@@ -190,7 +190,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
# save the last weights
last_path = None
if (
-self.lightning_module.trainer.state == TrainerState.FITTING and best_model_path is not None
+self.lightning_module.trainer.state.fn == TrainerFn.FITTING and best_model_path is not None
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytorch_lightning as pl
-from pytorch_lightning.trainer.states import TrainerState
+from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -31,14 +31,14 @@ def verify_loop_configurations(self, model: 'pl.LightningModule') -> None:
model: The model to check the configuration.
"""
-if self.trainer.state in (TrainerState.FITTING, TrainerState.TUNING):
+if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
self.__verify_train_loop_configuration(model)
self.__verify_eval_loop_configuration(model, 'val')
-elif self.trainer.state == TrainerState.VALIDATING:
+elif self.trainer.state.fn == TrainerFn.VALIDATING:
self.__verify_eval_loop_configuration(model, 'val')
-elif self.trainer.state == TrainerState.TESTING:
+elif self.trainer.state.fn == TrainerFn.TESTING:
self.__verify_eval_loop_configuration(model, 'test')
-elif self.trainer.state == TrainerState.PREDICTING:
+elif self.trainer.state.fn == TrainerFn.PREDICTING:
self.__verify_predict_loop_configuration(model)
self.__verify_dp_batch_transfer_support(model)

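
For orientation, the dispatch above implies the following pairing between `Trainer` entry points, `TrainerFn` values, and the loop hooks whose configuration gets verified; the entry-point column is inferred from the enum names rather than stated in the hunk:

    # Illustrative summary only -- not code from the commit.
    VERIFIED_LOOPS = {
        "Trainer.fit":      ("TrainerFn.FITTING",    ["train loop", "val loop"]),
        "Trainer.tune":     ("TrainerFn.TUNING",     ["train loop", "val loop"]),
        "Trainer.validate": ("TrainerFn.VALIDATING", ["val loop"]),
        "Trainer.test":     ("TrainerFn.TESTING",    ["test loop"]),
        "Trainer.predict":  ("TrainerFn.PREDICTING", ["predict loop"]),
    }
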
@@ -19,7 +19,7 @@

import pytorch_lightning as pl
from pytorch_lightning.core.step_result import Result
-from pytorch_lightning.trainer.states import TrainerState
+from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import DistributedType, LightningEnum


@@ -355,7 +355,7 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]:

# TODO(carmocca): when we implement flushing the logger connector metrics after
# the trainer.state changes, this should check trainer.evaluating instead
-if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING):
+if self.trainer.state.fn in (TrainerFn.TESTING, TrainerFn.VALIDATING):
logger_connector.evaluation_callback_metrics.update(callback_metrics)

# update callback_metrics
@@ -24,7 +24,7 @@
from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore
from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
-from pytorch_lightning.trainer.states import RunningStage, TrainerState
+from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities import DeviceType
from pytorch_lightning.utilities.metrics import metrics_to_scalars
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT
@@ -78,7 +78,7 @@ def progress_bar_metrics(self, progress_bar_metrics: Dict) -> None:

@property
def cached_results(self) -> Union[EpochResultStore, None]:
-return self._cached_results.get(self.trainer._running_stage)
+return self._cached_results.get(self.trainer.state.stage)

def get_metrics(self, key: str) -> Dict:
metrics_holder: MetricsHolder = getattr(self, f"_{key}")
@@ -116,7 +116,7 @@ def on_train_batch_end(self) -> None:
self.cached_results._batch_size = None

def cache_logged_metrics(self):
-self._cached_results[self.trainer._running_stage].cache_result()
+self._cached_results[self.trainer.state.stage].cache_result()

def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool):
# logging
@@ -279,12 +279,12 @@ def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT:

# log results of evaluation
if (
-self.trainer.state != TrainerState.FITTING and self.trainer.evaluating and self.trainer.is_global_zero
+self.trainer.state.fn != TrainerFn.FITTING and self.trainer.evaluating and self.trainer.is_global_zero
and self.trainer.verbose_evaluate
):
print('-' * 80)
for result_idx, results in enumerate(self.eval_loop_results):
-print(f'DATALOADER:{result_idx} {self.trainer._running_stage.upper()} RESULTS')
+print(f'DATALOADER:{result_idx} {self.trainer.state.stage.upper()} RESULTS')
pprint({
k: (v.item() if v.numel() == 1 else v.tolist()) if isinstance(v, torch.Tensor) else v
for k, v in results.items()
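
In the two logger-connector files above, cache lookups and the results banner now key off `trainer.state.stage` (a `RunningStage`, which per the CHANGELOG also covers sanity checking) instead of the private `trainer._running_stage`. A stripped-down illustration of that keying — the dictionary here is illustrative, not the connector's real data structure:

    from pytorch_lightning.trainer.states import RunningStage

    # one result cache per stage, looked up via the public trainer.state.stage
    cached_results = {stage: [] for stage in RunningStage}


    def cache_for_current_stage(trainer, value):
        cached_results[trainer.state.stage].append(value)
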
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/connectors/profiler_connector.py
@@ -58,4 +58,4 @@ def setup(self) -> None:
trainer = self.trainer
local_rank = trainer.local_rank if trainer.world_size > 1 else None
trainer.profiler._lightning_module = proxy(trainer.lightning_module)
-trainer.profiler.setup(stage=trainer._setup_state, local_rank=local_rank, log_dir=trainer.log_dir)
+trainer.profiler.setup(stage=trainer.state.fn._setup_fn, local_rank=local_rank, log_dir=trainer.log_dir)
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/data_loading.py
@@ -355,7 +355,7 @@ def _reset_eval_dataloader(

# add samplers
dataloaders = [
-self.auto_add_sampler(dl, shuffle=False, mode=self._running_stage) for dl in dataloaders if dl is not None
+self.auto_add_sampler(dl, shuffle=False, mode=self.state.stage) for dl in dataloaders if dl is not None
]

# add worker_init_fn for correct seeding in worker processes
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -17,7 +17,7 @@

import pytorch_lightning as pl
from pytorch_lightning.core.step_result import Result
-from pytorch_lightning.trainer.states import TrainerState
+from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.trainer.supporters import PredictionCollection
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
@@ -101,7 +101,7 @@ def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
else:
self.trainer.call_hook('on_validation_end', *args, **kwargs)

-if self.trainer.state != TrainerState.FITTING:
+if self.trainer.state.fn != TrainerFn.FITTING:
# summarize profile results
self.trainer.profiler.describe()

(Remaining changed files not shown in this view.)
