diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8ab18a66d37f5..3d4eeddb5b6ca 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -291,6 +291,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed passing wrong strings for scheduler interval doesn't throw an error ([#5923](https://github.com/PyTorchLightning/pytorch-lightning/pull/5923))
 
+- Fixed wrong `requires_grad` state after `return None` with multiple optimizers ([#5738](https://github.com/PyTorchLightning/pytorch-lightning/pull/5738))
+
+
 - Fixed add `on_epoch_end` hook at the end of `validation`, `test` epoch ([#5986](https://github.com/PyTorchLightning/pytorch-lightning/pull/5986))
 
 
@@ -303,6 +306,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed synchronization issues with TPU training ([#6027](https://github.com/PyTorchLightning/pytorch-lightning/pull/6027))
 
+
 ## [1.1.8] - 2021-02-08
 
 ### Fixed
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 57c0b10f12412..6dc73b55ef53b 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -768,24 +768,23 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
             result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
             self._curr_step_result = result
 
-            if result is None:
-                if self.automatic_optimization:
-                    self.warning_cache.warn("training_step returned None if it was on purpose, ignore this warning...")
-                return None
-
             if not self._skip_backward and self.trainer.train_loop.automatic_optimization:
                 # backward pass
-                with self.trainer.profiler.profile("model_backward"):
-                    self.backward(result, optimizer, opt_idx)
+                if result is not None:
+                    with self.trainer.profiler.profile("model_backward"):
+                        self.backward(result, optimizer, opt_idx)
 
-                # hook - call this hook only
-                # when gradients have finished to accumulate
-                if not self.should_accumulate():
-                    self.on_after_backward(result.training_step_output, batch_idx, result.loss)
+                    # hook - call this hook only
+                    # when gradients have finished to accumulate
+                    if not self.should_accumulate():
+                        self.on_after_backward(result.training_step_output, batch_idx, result.loss)
 
-                # check if loss or model weights are nan
-                if self.trainer.terminate_on_nan:
-                    self.trainer.detect_nan_tensors(result.loss)
+                    # check if loss or model weights are nan
+                    if self.trainer.terminate_on_nan:
+                        self.trainer.detect_nan_tensors(result.loss)
+
+                else:
+                    self.warning_cache.warn("training_step returned None if it was on purpose, ignore this warning...")
 
             if len(self.trainer.optimizers) > 1:
                 # revert back to previous state
diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py
index a63f4107a63fe..17449b96d7cab 100644
--- a/tests/core/test_lightning_module.py
+++ b/tests/core/test_lightning_module.py
@@ -385,7 +385,9 @@ def optimizer_step(
             optimizer.step(closure=closure)
 
         def training_step(self, batch, batch_idx, optimizer_idx=None):
-            return super().training_step(batch, batch_idx)
+            loss = super().training_step(batch, batch_idx)
+            # make sure the model is untoggled when returning None
+            return loss if batch_idx % 2 == 0 else None
 
         @staticmethod
         def combine_generators(gen_1, gen_2):
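
For context, the sketch below mirrors the scenario this patch exercises: a LightningModule with two optimizers whose training_step skips every other batch by returning None, which previously bypassed untoggle_optimizer and left the other optimizer's parameters with requires_grad=False. This is a minimal sketch against the PyTorch Lightning 1.2-era API; the class name, layer names, and optimizer choice are illustrative assumptions and are not part of the patch.

# Minimal sketch (not part of the patch): two optimizers, training_step
# occasionally returns None. With this fix, the trainer still reverts the
# optimizer toggling even when a step is skipped. Names are illustrative.
import torch
from torch import nn
import pytorch_lightning as pl


class TwoOptimizerModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.layer_a = nn.Linear(32, 32)
        self.layer_b = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer_b(self.layer_a(x))

    def training_step(self, batch, batch_idx, optimizer_idx):
        # skip every other step on purpose; the trainer should still untoggle
        if batch_idx % 2 != 0:
            return None
        return self(batch).sum()

    def configure_optimizers(self):
        opt_a = torch.optim.SGD(self.layer_a.parameters(), lr=0.1)
        opt_b = torch.optim.SGD(self.layer_b.parameters(), lr=0.1)
        return opt_a, opt_b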