From 3a9fde915ad4c69620a6ccc411f5890cb38ba5ac Mon Sep 17 00:00:00 2001
From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com>
Date: Wed, 24 Mar 2021 02:17:01 -0700
Subject: [PATCH] Revert "checkpoint consolidation"

This reverts commit 536c1323b0e6715fb5919196ea48b0fcddddcd66.
---
 pytorch_lightning/callbacks/base.py           |  4 ---
 pytorch_lightning/callbacks/early_stopping.py | 15 --------
 .../callbacks/lambda_function.py              |  3 --
 .../callbacks/model_checkpoint.py             | 31 ----------------
 pytorch_lightning/trainer/callback_hook.py    |  7 ----
 .../callback_hook_validator.py                |  5 ---
 pytorch_lightning/trainer/training_loop.py    | 35 +++++++++++++++++--
 tests/checkpointing/test_model_checkpoint.py  | 35 ++++---------------
 tests/helpers/utils.py                        |  2 +-
 .../trainer/logging_/test_logger_connector.py |  1 -
 10 files changed, 39 insertions(+), 99 deletions(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index ffb26f38ca821..db507fa991446 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -109,10 +109,6 @@ def on_epoch_end(self, trainer, pl_module: LightningModule) -> None:
         """Called when the epoch ends."""
         pass

-    def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None:
-        """Called when at the very end of train epoch."""
-        pass
-
     def on_batch_start(self, trainer, pl_module: LightningModule) -> None:
         """Called when the training batch begins."""
         pass
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 0de8ff6f0b505..4448de8e4834b 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -143,21 +143,6 @@ def on_validation_end(self, trainer, pl_module):

         self._run_early_stopping_check(trainer)

-    def on_train_epoch_final_end(self, trainer, pl_module):
-        from pytorch_lightning.trainer.states import TrainerState
-        if (
-            trainer.state != TrainerState.FITTING or trainer.sanity_checking
-            or not trainer.checkpoint_connector.has_trained
-        ):
-            return
-        # if validation is disabled or should skip, we run early stopping
-        # at end of the training epoch
-        if (
-            trainer.disable_validation
-            or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches)
-        ):
-            self._run_early_stopping_check(trainer)
-
     def _run_early_stopping_check(self, trainer):
         """
         Checks whether the early stopping condition is met
diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py
index 2a56e1c8ac6e0..58324e363cd37 100644
--- a/pytorch_lightning/callbacks/lambda_function.py
+++ b/pytorch_lightning/callbacks/lambda_function.py
@@ -53,7 +53,6 @@ def __init__(
         on_train_batch_end: Optional[Callable] = None,
         on_train_epoch_start: Optional[Callable] = None,
         on_train_epoch_end: Optional[Callable] = None,
-        on_train_epoch_final_end: Optional[Callable] = None,
         on_validation_epoch_start: Optional[Callable] = None,
         on_validation_epoch_end: Optional[Callable] = None,
         on_test_epoch_start: Optional[Callable] = None,
@@ -156,5 +155,3 @@ def __init__(
             self.on_after_backward = on_after_backward
         if on_before_zero_grad is not None:
             self.on_before_zero_grad = on_before_zero_grad
-        if on_train_epoch_final_end is not None:
-            self.on_train_epoch_final_end = on_train_epoch_final_end
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 9436720e3819b..2a0c108ba7603 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -238,37 +238,6 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None:
             return
         self.save_checkpoint(trainer)

-    def on_train_epoch_final_end(self, trainer, pl_module):
-        """
-        at the end of each training epoch, checkpoint only when validation is skipped or disabled
-        """
-        print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step))
-        if (
-            self._should_skip_saving_checkpoint(trainer)
-            or not trainer.checkpoint_connector.has_trained
-        ):
-            return
-        # if validation is disabled or should skip, we checkpoint at end of the training epoch
-        if (
-            trainer.disable_validation
-            or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches)
-        ):
-            self.save_checkpoint(trainer)
-
-    def on_train_end(self, trainer, *args, **kwargs) -> None:
-        """
-        checkpoints can be saved at the end of the trianing
-        """
-        trainer.global_step -= 1
-        if (
-            not self._should_skip_saving_checkpoint(trainer)
-            and trainer.checkpoint_connector.has_trained
-        ):
-            if self.save_last and self.verbose:
-                rank_zero_info("Saving latest checkpoint...")
-            self.save_checkpoint(trainer)
-        trainer.global_step += 1
-
     def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
             "monitor": self.monitor,
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index c53c21ad04bc3..8823d48a7817e 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -92,13 +92,6 @@ def on_train_epoch_end(self, outputs: List[Any]):
         for callback in self.callbacks:
             callback.on_train_epoch_end(self, self.lightning_module, outputs)

-    def on_train_epoch_final_end(self) -> None:
-        """
-        Called when at the very end of train epoch.
-        """
-        for callback in self.callbacks:
-            callback.on_train_epoch_final_end(self, self.lightning_module)
-
     def on_validation_epoch_start(self):
         """Called when the epoch begins."""
         for callback in self.callbacks:
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py
index e7884124df314..534dad5199e9b 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py
@@ -100,11 +100,6 @@ def _on_train_epoch_end_log():
         """Called when the epoch ends."""
         return {"on_step": [False], "on_epoch": [False, True]}

-    @staticmethod
-    def _on_train_epoch_final_end_log():
-        """Called when at the very end of train epoch."""
-        return {"on_step": [False], "on_epoch": [False, True]}
-
     @staticmethod
     def _on_validation_epoch_start_log():
         """Called when the epoch begins."""
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 1d498a0a9ff6c..c3ba34ca66d2d 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -121,6 +121,12 @@ def on_train_end(self):
             return
         self._teardown_already_run = True

+        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
+        # when a checkpoint was saved at the last step
+        self.trainer.global_step -= 1
+        self.check_checkpoint_callback(should_update=True, is_last=True)
+        self.trainer.global_step += 1
+
         # hook
         self.trainer.call_hook("on_train_end")

@@ -139,6 +145,28 @@ def on_train_end(self):
         # reset bookkeeping
         self.trainer._running_stage = None

+    def check_checkpoint_callback(self, should_update, is_last=False):
+        # TODO bake this logic into the ModelCheckpoint callback
+        if should_update and self.trainer.checkpoint_connector.has_trained:
+            callbacks = self.trainer.checkpoint_callbacks
+
+            if is_last and any(cb.save_last and cb.verbose for cb in callbacks):
+                rank_zero_info("Saving latest checkpoint...")
+
+            model = self.trainer.lightning_module
+
+            for cb in callbacks:
+                cb.on_validation_end(self.trainer, model)
+
+    def check_early_stopping_callback(self, should_update):
+        # TODO bake this logic into the EarlyStopping callback
+        if should_update and self.trainer.checkpoint_connector.has_trained:
+            callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
+            model = self.trainer.lightning_module
+
+            for cb in callbacks:
+                cb.on_validation_end(self.trainer, model)
+
     def on_train_epoch_start(self, epoch):

         # update training progress in trainer
@@ -534,14 +562,15 @@ def run_training_epoch(self):
         if (val_loop_called and not should_check_val) or should_train_only:
             self.trainer.optimizer_connector.update_learning_rates(interval='epoch')

+            if should_train_only:
+                self.check_checkpoint_callback(True)
+                self.check_early_stopping_callback(True)
+
         if should_check_val:
             self.trainer.validating = True
             self.trainer.run_evaluation(on_epoch=True)
             self.trainer.training = True

-        if should_train_only:
-            self.trainer.call_hook('on_train_epoch_final_end')
-
         # increment the global step once
         # progress global step according to grads progress
         self.increment_accumulated_grad_global_step()
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index e0c295a843a21..75f25b90fa45f 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -609,13 +609,7 @@ def test_model_checkpoint_period(tmpdir, period: int):
     trainer.fit(model)

     # check that the correct ckpts were created
-    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
-    expected = (
-        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs]
-        if period > 0
-        else []
-    )
-    expected.append(final_epoch_ckpt)
+    expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else []
     assert set(os.listdir(tmpdir)) == set(expected)


@@ -637,14 +631,8 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
     trainer.fit(model)

     # check that the correct ckpts were created
-    # check that the correct ckpts were created
-    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
-    expected = (
-        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs]
-        if every_n_val_epochs > 0
-        else []
-    )
-    expected.append(final_epoch_ckpt)
+    expected = [f'epoch={e}.ckpt' for e in range(epochs)
+                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
     assert set(os.listdir(tmpdir)) == set(expected)


@@ -671,14 +659,8 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc
     trainer.fit(model)

     # check that the correct ckpts were created
-    # check that the correct ckpts were created
-    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
-    expected = (
-        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs]
-        if every_n_val_epochs > 0
-        else []
-    )
-    expected.append(final_epoch_ckpt)
+    expected = [f'epoch={e}.ckpt' for e in range(epochs)
+                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
     assert set(os.listdir(tmpdir)) == set(expected)


@@ -834,15 +816,10 @@ def test_model_checkpoint_save_last_warning(
         default_root_dir=tmpdir,
         callbacks=[ckpt],
         max_epochs=max_epochs,
-        val_check_interval=0.1,
     )
     with caplog.at_level(logging.INFO):
         trainer.fit(model)
-    if verbose and save_last and not should_validate:
-        # no validation, hence checkpoint triggered at the end of each training epoch
-        assert caplog.messages.count('Saving latest checkpoint...') == False
-    else:
-        assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)
+    assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)


 def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py
index 493d32d3fe454..f5c1726a423bb 100644
--- a/tests/helpers/utils.py
+++ b/tests/helpers/utils.py
@@ -76,7 +76,7 @@ def reset_seed(seed=0):
 def set_random_master_port():
     reset_seed()
     port = RANDOM_PORTS.pop()
-    os.environ['MASTER_PORT'] = "29501"
+    os.environ['MASTER_PORT'] = str(port)


 def init_checkpoint_callback(logger):
diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index b2727177bcacd..3db0a8eaa065b 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -300,7 +300,6 @@ def test_call_back_validator(tmpdir):
         'on_train_batch_start',
         'on_train_end',
         'on_train_epoch_end',
-        'on_train_epoch_final_end',
         'on_train_epoch_start',
         'on_train_start',
         'on_validation_batch_end',