TrainerState refactor [5/5] #7173

Merged: 19 commits on May 4, 2021
7 changes: 5 additions & 2 deletions CHANGELOG.md
@@ -46,7 +46,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `RunningStage.SANITY_CHECKING` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


- Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
- Added `TrainerFn.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945), [#7173](https://github.com/PyTorchLightning/pytorch-lightning/pull/7173))


- Added `TrainerStatus.{INITIALIZING,RUNNING,FINISHED,INTERRUPTED}` ([#7173](https://github.com/PyTorchLightning/pytorch-lightning/pull/7173))


- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ([#4948](https://github.com/PyTorchLightning/pytorch-lightning/pull/4948))
@@ -157,7 +160,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))


- Refactor `RunningStage` and `TrainerState` usage ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
- Refactor `RunningStage` and `TrainerState` usage ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945), [#7173](https://github.com/PyTorchLightning/pytorch-lightning/pull/7173))


- Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
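
The changelog entries above introduce `TrainerFn` and `TrainerStatus` alongside the existing `RunningStage`. Below is a minimal sketch of how the refactored `trainer.state` object is composed, inferred from the new enums listed above and the `state.fn` / `state.stage` accesses used throughout this diff; the actual definitions live in `pytorch_lightning/trainer/states.py` and may differ in detail (for example, the enums may derive from `LightningEnum` rather than plain `Enum`).

```python
# Hedged sketch of the refactored trainer state; attribute names are taken from
# this diff, string values and the dataclass layout are assumptions.
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class TrainerFn(Enum):
    """Which user-facing entry point was called."""
    FITTING = "fit"
    VALIDATING = "validate"
    TESTING = "test"
    PREDICTING = "predict"
    TUNING = "tune"


class TrainerStatus(Enum):
    """Lifecycle of that call."""
    INITIALIZING = "initializing"
    RUNNING = "running"
    FINISHED = "finished"
    INTERRUPTED = "interrupted"


class RunningStage(Enum):
    """Fine-grained loop stage currently executing."""
    TRAINING = "train"
    SANITY_CHECKING = "sanity_check"
    VALIDATING = "validate"
    TESTING = "test"
    PREDICTING = "predict"
    TUNING = "tune"


@dataclass
class TrainerState:
    status: TrainerStatus = TrainerStatus.INITIALIZING
    fn: Optional[TrainerFn] = None
    stage: Optional[RunningStage] = None
```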
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/accelerator.py
@@ -24,7 +24,7 @@
import pytorch_lightning as pl
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
@@ -374,7 +374,7 @@ def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
Args:
trainer: the Trainer, these optimizers should be connected to
"""
if trainer.state not in (TrainerState.FITTING, TrainerState.TUNING):
if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
return
optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers(
trainer=trainer, model=self.lightning_module
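
The accelerator change above is representative of the mechanical migration applied across this PR: comparisons against the old `TrainerState` enum become comparisons against `trainer.state.fn`. A hedged before/after sketch; the helper function is illustrative and not part of the `Accelerator` API.

```python
from pytorch_lightning.trainer.states import TrainerFn


def needs_optimizers(trainer) -> bool:
    """Illustrative helper: True only for the entry points that require optimizer setup."""
    # before this PR: trainer.state in (TrainerState.FITTING, TrainerState.TUNING)
    # after:          the entry point now lives on the composite state object
    return trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING)
```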
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -158,8 +158,8 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:
self.patience = callback_state['patience']

def _should_skip_check(self, trainer) -> bool:
from pytorch_lightning.trainer.states import TrainerState
return trainer.state != TrainerState.FITTING or trainer.sanity_checking
from pytorch_lightning.trainer.states import TrainerFn
return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking

def on_train_epoch_end(self, trainer, pl_module, outputs) -> None:
if not self._check_on_train_epoch_end or self._should_skip_check(trainer):
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -286,10 +286,10 @@ def save_checkpoint(self, trainer, unused: Optional = None):
self._save_last_checkpoint(trainer, monitor_candidates)

def _should_skip_saving_checkpoint(self, trainer) -> bool:
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.states import TrainerFn
return (
trainer.fast_dev_run # disable checkpointing with fast_dev_run
or trainer.state != TrainerState.FITTING # don't save anything during non-fit
or trainer.state.fn != TrainerFn.FITTING # don't save anything during non-fit
or trainer.sanity_checking # don't save anything during sanity check
or self._last_global_step_saved == trainer.global_step # already saved at the last step
)
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/quantization.py
@@ -30,7 +30,7 @@

def wrap_qat_forward_context(
quant_cb,
model: pl.core.LightningModule,
model: 'pl.LightningModule',
func: Callable,
trigger_condition: Optional[Union[Callable, int]] = None
) -> Callable:
@@ -57,7 +57,7 @@ def wrapper(data) -> Any:
return wrapper


def wrap_quantize_forward_context(model: pl.core.LightningModule, func: Callable) -> Callable:
def wrap_quantize_forward_context(model: 'pl.LightningModule', func: Callable) -> Callable:
"""
Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out compatibility
"""
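
The quantization change is independent of the state refactor: the annotations switch to the string form `'pl.LightningModule'`. A short sketch of the forward-reference pattern follows; the rationale given here (deferring evaluation of the annotation to avoid circular-import problems between `pytorch_lightning` and its submodules) is assumed rather than stated in the diff.

```python
# A string annotation is never evaluated at function-definition time, so this
# module can be imported even while pytorch_lightning is still mid-import;
# only type checkers or typing.get_type_hints() resolve it later.
from typing import Callable

import pytorch_lightning as pl  # needed at runtime elsewhere, not for the annotation


def wrap_forward(model: 'pl.LightningModule', func: Callable) -> Callable:
    def wrapper(*args, **kwargs):
        # the wrapped callable is used unchanged; only the annotation style matters here
        return func(*args, **kwargs)

    return wrapper
```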
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -27,7 +27,7 @@
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
@@ -245,7 +245,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
# save the last weights
last_path = None
if (
self.lightning_module.trainer.state == TrainerState.FITTING and best_model_path is not None
self.lightning_module.trainer.state.fn == TrainerFn.FITTING and best_model_path is not None
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
@@ -263,7 +263,7 @@ def __recover_child_process_weights(self, best_path, last_path):
# todo, pass also best score

# load last weights
if last_path is not None and self.lightning_module.trainer.state == TrainerState.FITTING:
if last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
ckpt = pl_load(last_path, map_location=lambda storage, loc: storage)
self.lightning_module.load_state_dict(ckpt)

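
The DDP-spawn plugin keeps its existing weight hand-off, now gated on `trainer.state.fn`: the spawned worker writes the final weights next to the best checkpoint, and the parent process reloads them after the workers exit. A self-contained, hedged sketch of that pattern; the function names are illustrative, not the plugin's actual API.

```python
import os
import re
from typing import Optional

import torch


def save_last_weights_in_child(module: torch.nn.Module, best_model_path: Optional[str]) -> Optional[str]:
    """Child process: persist the final weights next to the best checkpoint."""
    if not best_model_path:
        return None
    last_path = re.sub(r"\.ckpt$", ".tmp_end.ckpt", best_model_path)
    torch.save(module.state_dict(), last_path)
    return last_path


def recover_weights_in_parent(module: torch.nn.Module, last_path: Optional[str]) -> None:
    """Parent process: reload the weights written by the spawned worker."""
    if last_path is not None and os.path.exists(last_path):
        state_dict = torch.load(last_path, map_location="cpu")
        module.load_state_dict(state_dict)
```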
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -510,16 +510,16 @@ def restore_model_state_from_ckpt_path(
) -> Tuple[Dict, bool]:
if not self.save_full_weights and self.world_size > 1:
# Rely on deepspeed to load the checkpoint and necessary information
from pytorch_lightning.trainer.states import TrainerState
stage_is_fit = self.lightning_module.trainer.state == TrainerState.FITTING
from pytorch_lightning.trainer.states import TrainerFn
is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
save_dir = self._filepath_to_dir(ckpt_path)

if self.zero_stage_3:
# TODO: Currently required as this call is missing within the deepspeed engine.
self.deepspeed_engine.optimizer._partition_all_parameters()

_, client_state = self.deepspeed_engine.load_checkpoint(
save_dir, load_optimizer_states=stage_is_fit, load_lr_scheduler_states=stage_is_fit
save_dir, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting
)

# restore datamodule states
8 changes: 4 additions & 4 deletions pytorch_lightning/plugins/training_type/rpc_sequential.py
@@ -24,7 +24,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.distributed import LightningDistributedModule
from pytorch_lightning.plugins.training_type.rpc import DEFAULT_RPC_TIMEOUT_SEC, RPCPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -208,7 +208,7 @@ def _skip_init_connections(self):
Returns: Whether to skip initialization

"""
return torch_distrib.is_initialized() and self.lightning_module.trainer.state != TrainerState.FITTING
return torch_distrib.is_initialized() and self.lightning_module.trainer.state.fn != TrainerFn.FITTING

def init_model_parallel_groups(self):
num_model_parallel = 1 # TODO currently no support for vertical model parallel
@@ -231,7 +231,7 @@ def _infer_check_num_gpus(self):
return self.world_size

def handle_transferred_pipe_module(self) -> None:
if self.lightning_module.trainer.state == TrainerState.FITTING:
if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
torch_distrib.barrier() # Ensure we await main process initialization
# Add trainer/configure_optimizers to the pipe model for access in all worker processes
rpc_pipe.PipeModel.trainer = self.lightning_module.trainer
@@ -243,7 +243,7 @@ def init_pipe_module(self) -> None:
# Create pipe_module
model = self.lightning_module
self._find_and_init_pipe_module(model)
if self.lightning_module.trainer.state == TrainerState.FITTING:
if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
torch_distrib.barrier() # Ensure we join main process initialization
model.sequential_module.foreach_worker(register_optimizers, include_self=True)

4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/sharded.py
@@ -19,7 +19,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import is_lightning_optimizer
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -66,7 +66,7 @@ def _reinit_optimizers_with_oss(self):
trainer.convert_to_lightning_optimizers()

def _wrap_optimizers(self):
if self.model.trainer.state != TrainerState.FITTING:
if self.model.trainer.state.fn != TrainerFn.FITTING:
return
self._reinit_optimizers_with_oss()

4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -19,7 +19,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -53,7 +53,7 @@ def _reinit_optimizers_with_oss(self):
trainer.optimizers = optimizers

def _wrap_optimizers(self):
if self.model.trainer.state != TrainerState.FITTING:
if self.model.trainer.state.fn != TrainerFn.FITTING:
return
self._reinit_optimizers_with_oss()

4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -26,7 +26,7 @@
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.data import has_len
@@ -190,7 +190,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
# save the last weights
last_path = None
if (
self.lightning_module.trainer.state == TrainerState.FITTING and best_model_path is not None
self.lightning_module.trainer.state.fn == TrainerFn.FITTING and best_model_path is not None
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytorch_lightning as pl
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -31,14 +31,14 @@ def verify_loop_configurations(self, model: 'pl.LightningModule') -> None:
model: The model to check the configuration.

"""
if self.trainer.state in (TrainerState.FITTING, TrainerState.TUNING):
if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
self.__verify_train_loop_configuration(model)
self.__verify_eval_loop_configuration(model, 'val')
elif self.trainer.state == TrainerState.VALIDATING:
elif self.trainer.state.fn == TrainerFn.VALIDATING:
self.__verify_eval_loop_configuration(model, 'val')
elif self.trainer.state == TrainerState.TESTING:
elif self.trainer.state.fn == TrainerFn.TESTING:
self.__verify_eval_loop_configuration(model, 'test')
elif self.trainer.state == TrainerState.PREDICTING:
elif self.trainer.state.fn == TrainerFn.PREDICTING:
self.__verify_predict_loop_configuration(model)
self.__verify_dp_batch_transfer_support(model)

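
The configuration validator now dispatches on `trainer.state.fn` to decide which loop hooks to verify. A hedged usage sketch of how the public `Trainer` entry points map onto `TrainerFn` values; the mapping is inferred from the checks above, and `TinyModule` plus its dataloader are hypothetical stand-ins for any LightningModule.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus


class TinyModule(pl.LightningModule):
    """Hypothetical minimal module so the entry points below can run."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_loss", torch.nn.functional.mse_loss(self(x), y))

    def test_step(self, batch, batch_idx):
        x, y = batch
        self.log("test_loss", torch.nn.functional.mse_loss(self(x), y))

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        x, _ = batch
        return self(x)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def loader():
    x, y = torch.randn(32, 4), torch.randn(32, 1)
    return DataLoader(TensorDataset(x, y), batch_size=8)


model = TinyModule()
trainer = pl.Trainer(max_epochs=1)

trainer.fit(model, loader(), loader())  # while running: trainer.state.fn == TrainerFn.FITTING
trainer.validate(model, loader())       # TrainerFn.VALIDATING
trainer.test(model, loader())           # TrainerFn.TESTING
trainer.predict(model, loader())        # TrainerFn.PREDICTING

assert trainer.state.status == TrainerStatus.FINISHED  # after a successful run
```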
@@ -19,7 +19,7 @@

import pytorch_lightning as pl
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import DistributedType, LightningEnum


@@ -355,7 +355,7 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]:

# TODO(carmocca): when we implement flushing the logger connector metrics after
# the trainer.state changes, this should check trainer.evaluating instead
if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING):
if self.trainer.state.fn in (TrainerFn.TESTING, TrainerFn.VALIDATING):
logger_connector.evaluation_callback_metrics.update(callback_metrics)

# update callback_metrics
@@ -24,7 +24,7 @@
from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore
from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
from pytorch_lightning.trainer.states import RunningStage, TrainerState
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities import DeviceType
from pytorch_lightning.utilities.metrics import metrics_to_scalars
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT
@@ -78,7 +78,7 @@ def progress_bar_metrics(self, progress_bar_metrics: Dict) -> None:

@property
def cached_results(self) -> Union[EpochResultStore, None]:
return self._cached_results.get(self.trainer._running_stage)
return self._cached_results.get(self.trainer.state.stage)

def get_metrics(self, key: str) -> Dict:
metrics_holder: MetricsHolder = getattr(self, f"_{key}")
@@ -116,7 +116,7 @@ def on_train_batch_end(self) -> None:
self.cached_results._batch_size = None

def cache_logged_metrics(self):
self._cached_results[self.trainer._running_stage].cache_result()
self._cached_results[self.trainer.state.stage].cache_result()

def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool):
# logging
@@ -279,12 +279,12 @@ def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT:

# log results of evaluation
if (
self.trainer.state != TrainerState.FITTING and self.trainer.evaluating and self.trainer.is_global_zero
self.trainer.state.fn != TrainerFn.FITTING and self.trainer.evaluating and self.trainer.is_global_zero
and self.trainer.verbose_evaluate
):
print('-' * 80)
for result_idx, results in enumerate(self.eval_loop_results):
print(f'DATALOADER:{result_idx} {self.trainer._running_stage.upper()} RESULTS')
print(f'DATALOADER:{result_idx} {self.trainer.state.stage.upper()} RESULTS')
pprint({
k: (v.item() if v.numel() == 1 else v.tolist()) if isinstance(v, torch.Tensor) else v
for k, v in results.items()
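
The logger connector now keys its caches on `trainer.state.stage` (the fine-grained `RunningStage`) while result printing is gated on `trainer.state.fn` (the entry point), so validation run inside `fit` is cached like any other stage but is not printed as a standalone evaluation. A hedged sketch of that split; `ResultCache` is illustrative, not the connector's actual class.

```python
# Caches keyed by stage, reporting gated by fn (mirrors the connector changes above).
from pytorch_lightning.trainer.states import RunningStage, TrainerFn


class ResultCache:
    def __init__(self):
        # one store per fine-grained loop stage
        self._stores = {stage: {} for stage in RunningStage}

    def cache(self, trainer, key, value):
        self._stores[trainer.state.stage][key] = value

    def report(self, trainer):
        # only print standalone evaluation results, never mid-fit validation
        if trainer.state.fn != TrainerFn.FITTING and trainer.evaluating:
            print(self._stores[trainer.state.stage])
```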
@@ -58,4 +58,4 @@ def setup(self) -> None:
trainer = self.trainer
local_rank = trainer.local_rank if trainer.world_size > 1 else None
trainer.profiler._lightning_module = proxy(trainer.lightning_module)
trainer.profiler.setup(stage=trainer._setup_state, local_rank=local_rank, log_dir=trainer.log_dir)
trainer.profiler.setup(stage=trainer.state.fn._setup_fn, local_rank=local_rank, log_dir=trainer.log_dir)
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/data_loading.py
@@ -355,7 +355,7 @@ def _reset_eval_dataloader(

# add samplers
dataloaders = [
self.auto_add_sampler(dl, shuffle=False, mode=self._running_stage) for dl in dataloaders if dl is not None
self.auto_add_sampler(dl, shuffle=False, mode=self.state.stage) for dl in dataloaders if dl is not None
]

# add worker_init_fn for correct seeding in worker processes
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -17,7 +17,7 @@

import pytorch_lightning as pl
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.trainer.supporters import PredictionCollection
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
@@ -101,7 +101,7 @@ def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
else:
self.trainer.call_hook('on_validation_end', *args, **kwargs)

if self.trainer.state != TrainerState.FITTING:
if self.trainer.state.fn != TrainerFn.FITTING:
# summarize profile results
self.trainer.profiler.describe()
