diff --git a/CHANGELOG.md b/CHANGELOG.md
index 41ec984da3c88..ed7eec7cff7f9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -75,6 +75,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207))
 
+- Removed `optimizer_idx` argument from `training_step` in manual optimization ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))
+
+
 ### Fixed
 
 - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
diff --git a/README.md b/README.md
index 08bba6c2db6c7..dd30bd94d78a6 100644
--- a/README.md
+++ b/README.md
@@ -318,9 +318,9 @@ class LitAutoEncoder(pl.LightningModule):
         super().__init__()
         self.automatic_optimization = False
 
-    def training_step(self, batch, batch_idx, optimizer_idx):
+    def training_step(self, batch, batch_idx):
         # access your optimizers with use_pl_optimizer=False. Default is True
-        (opt_a, opt_b) = self.optimizers(use_pl_optimizer=True)
+        opt_a, opt_b = self.optimizers(use_pl_optimizer=True)
 
         loss_a = ...
         self.manual_backward(loss_a, opt_a)
diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst
index c02f23ac60d09..ec257bf444f5c 100644
--- a/docs/source/common/lightning_module.rst
+++ b/docs/source/common/lightning_module.rst
@@ -952,14 +952,12 @@ When set to ``False``, Lightning does not automate the optimization process. Thi
 This is recommended only if using 2+ optimizers AND if you know how to perform the optimization procedure properly.
 Note that automatic optimization can still be used with multiple optimizers by relying on the ``optimizer_idx`` parameter.
 Manual optimization is most useful for research topics like reinforcement learning, sparse coding, and GAN research.
 
-In the multi-optimizer case, ignore the ``optimizer_idx`` argument and use the optimizers directly
-
 .. code-block:: python
 
     def __init__(self):
         self.automatic_optimization = False
 
-    def training_step(self, batch, batch_idx, optimizer_idx):
+    def training_step(self, batch, batch_idx):
         # access your optimizers with use_pl_optimizer=False. Default is True
         opt_a, opt_b = self.optimizers(use_pl_optimizer=True)
diff --git a/docs/source/common/optimizers.rst b/docs/source/common/optimizers.rst
index 3b29fd4c08f13..10813a94c35d2 100644
--- a/docs/source/common/optimizers.rst
+++ b/docs/source/common/optimizers.rst
@@ -51,7 +51,10 @@ to manually manage the optimization process. To do so, do the following:
 
 .. code-block:: python
 
-    def training_step(batch, batch_idx):
+    def __init__(self):
+        self.automatic_optimization = False
+
+    def training_step(self, batch, batch_idx):
         opt = self.optimizers()
 
         loss = self.compute_loss(batch)
@@ -69,7 +72,10 @@ Here is the same example as above using a ``closure``.
 
 .. testcode:: python
 
-    def training_step(batch, batch_idx):
+    def __init__(self):
+        self.automatic_optimization = False
+
+    def training_step(self, batch, batch_idx):
         opt = self.optimizers()
 
         def forward_and_backward():
@@ -126,7 +132,6 @@ Here is the same example as above using a ``closure``.
         # Optimize Discriminator #
         ###########################
         d_opt.zero_grad()
-
         d_x = self.D(X)
         errD_real = self.criterion(d_x, real_label)
@@ -179,6 +184,9 @@ Here is an example for advanced use-case.
 
     ...
 
+    def __init__(self):
+        self.automatic_optimization = False
+
     def training_step(self, batch, batch_idx):
         # Implementation follows https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
         g_opt, d_opt = self.optimizers()
diff --git a/docs/source/starter/new-project.rst b/docs/source/starter/new-project.rst
index 0f1362616a9b1..553f312afe9b9 100644
--- a/docs/source/starter/new-project.rst
+++ b/docs/source/starter/new-project.rst
@@ -265,7 +265,7 @@ Turn off automatic optimization and you control the train loop!
 
     def __init__(self):
         self.automatic_optimization = False
 
-    def training_step(self, batch, batch_idx, optimizer_idx):
+    def training_step(self, batch, batch_idx):
         # access your optimizers with use_pl_optimizer=False. Default is True
         opt_a, opt_b = self.optimizers(use_pl_optimizer=True)
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 57ae9557139e0..7ad035e1f56b4 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -310,7 +310,7 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
         closure_loss = None
         untouched_loss = None
 
-        if self.trainer.train_loop.automatic_optimization:
+        if self.automatic_optimization:
             # accumulate loss
             # (if accumulate_grad_batches = 1 no effect)
             if is_result_obj:
@@ -840,12 +840,17 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
 
         if len(self.trainer.optimizers) > 1:
             if self.trainer.has_arg("training_step", "optimizer_idx"):
+                if not self.automatic_optimization:
+                    self.warning_cache.warn(
+                        "`training_step` hook signature has changed in v1.3."
+                        " `optimizer_idx` argument has been removed in case of manual optimization. Support for"
+                        " the old signature will be removed in v1.5", DeprecationWarning
+                    )
                 args.append(opt_idx)
-            else:
-                num_opts = len(self.trainer.optimizers)
+            elif not self.trainer.has_arg("training_step", "optimizer_idx") and self.automatic_optimization:
                 raise ValueError(
-                    f"Your LightningModule defines {num_opts} optimizers but "
-                    f'training_step is missing the "optimizer_idx" argument.'
+                    f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but"
+                    ' `training_step` is missing the `optimizer_idx` argument.'
                 )
 
         # pass hiddens if using tbptt
diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py
index 3c6e34df8d5e3..8858129b221f9 100644
--- a/tests/core/test_lightning_optimizer.py
+++ b/tests/core/test_lightning_optimizer.py
@@ -89,8 +89,9 @@ def __init__(self):
             super().__init__()
             self.automatic_optimization = False
 
-        def training_step(self, batch, batch_idx, optimizer_idx=None):
+        def training_step(self, batch, batch_idx):
             opt_1, opt_2 = self.optimizers()
+
             assert isinstance(opt_1, LightningOptimizer)
             assert isinstance(opt_2, LightningOptimizer)
diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py
index 6385e02af33a6..cc9bcc9d56c06 100644
--- a/tests/deprecated_api/test_remove_1-5.py
+++ b/tests/deprecated_api/test_remove_1-5.py
@@ -16,6 +16,7 @@ from unittest import mock
 
 import pytest
+from torch import optim
 
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.loggers import WandbLogger
@@ -74,3 +75,25 @@ def test_v1_5_0_running_sanity_check():
     trainer = Trainer()
     with pytest.deprecated_call(match='has been renamed to `Trainer.sanity_checking`'):
         assert not trainer.running_sanity_check
+
+
+def test_old_training_step_signature_with_opt_idx_manual_opt(tmpdir):
+
+    class OldSignatureModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.automatic_optimization = False
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            assert optimizer_idx is not None
+            return super().training_step(batch, batch_idx)
+
+        def configure_optimizers(self):
+            return [optim.SGD(self.parameters(), lr=1e-2), optim.SGD(self.parameters(), lr=1e-2)]
+
+    model = OldSignatureModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
+
+    with pytest.deprecated_call(match="`training_step` .* `optimizer_idx` .* manual .* will be removed in v1.5"):
+        trainer.fit(model)
diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py
index 0fa8906106cb1..e197e9b35adc9 100644
--- a/tests/trainer/optimization/test_manual_optimization.py
+++ b/tests/trainer/optimization/test_manual_optimization.py
@@ -40,9 +40,9 @@ def __init__(self):
             super().__init__()
             self.automatic_optimization = False
 
-        def training_step(self, batch, batch_idx, optimizer_idx):
+        def training_step(self, batch, batch_idx):
             # manual
-            (opt_a, opt_b) = self.optimizers()
+            opt_a, opt_b = self.optimizers()
             loss_1 = self.step(batch[0])
 
             # make sure there are no grads
@@ -107,9 +107,9 @@ def __init__(self):
             super().__init__()
             self.automatic_optimization = False
 
-        def training_step(self, batch, batch_idx, optimizer_idx):
+        def training_step(self, batch, batch_idx):
             # manual
-            (opt_a, opt_b) = self.optimizers()
+            opt_a, opt_b = self.optimizers()
             loss_1 = self.step(batch[0])
 
             # make sure there are no grads
@@ -176,9 +176,9 @@ def __init__(self):
             super().__init__()
             self.automatic_optimization = False
 
-        def training_step(self, batch, batch_idx, optimizer_idx):
+        def training_step(self, batch, batch_idx):
             # manual
-            (opt_a, opt_b) = self.optimizers()
+            opt_a, opt_b = self.optimizers()
             loss_1 = self.step(batch[0])
 
             # make sure there are no grads
@@ -251,9 +251,9 @@ def __init__(self):
            super().__init__()
            self.automatic_optimization = False
 
-        def training_step(self, batch, batch_idx, optimizer_idx):
+        def training_step(self, batch, batch_idx):
             # manual
-            (opt_a, opt_b) = self.optimizers()
+            opt_a, opt_b = self.optimizers()
             loss_1 = self.step(batch[0])
 
             # make sure there are no grads
@@ -321,9 +321,9 @@ def __init__(self):
             super().__init__()
             self.automatic_optimization = False
 
-        def training_step(self, batch, batch_idx, optimizer_idx):
+        def training_step(self, batch, batch_idx):
             # manual
-            (opt_a, opt_b) = self.optimizers()
+            opt_a, opt_b = self.optimizers()
             x = batch[0]
 
             loss_1 = self(x)
@@ -610,9 +610,9 @@ def on_after_backward(self):
             if not (torch.isinf(norm) or torch.isnan(norm)):
                 assert norm.item() < 100, norm.item()
 
-        def training_step(self, batch, batch_idx, optimizer_idx):
+        def training_step(self, batch, batch_idx):
             # manual
-            (opt_a, opt_b) = self.optimizers()
+            opt_a, opt_b = self.optimizers()
             x = batch[0]
 
             loss_1 = self(x)
@@ -886,7 +886,7 @@ def __init__(self):
             super().__init__()
             self.automatic_optimization = False
 
-        def training_step(self, batch, batch_idx, optimizer_idx):
+        def training_step(self, batch, batch_idx):
             # emulate gans training
             opt_gen, opt_dis = self.optimizers()
 
@@ -981,7 +981,7 @@ def manual_sync_grad(self) -> bool:
         torch_distrib.all_reduce(self.layer.weight.grad.data, async_op=False)
         return True
 
-    def training_step(self, batch, batch_idx, optimizer_idx):
+    def training_step(self, batch, batch_idx):
         # emulate gans training
         opt_gen, opt_dis = self.optimizers()
 
@@ -1088,9 +1088,9 @@ def test_step_with_optimizer_closure_with_different_frequencies_ddp_spawn(tmpdir
     train_manual_optimization(tmpdir, "ddp_spawn")
 
 
-class TesManualOptimizationDDPModelToggleModel(TesManualOptimizationDDPModel):
+class TestManualOptimizationDDPModelToggleModel(TesManualOptimizationDDPModel):
 
-    def training_step(self, batch, batch_idx, optimizer_idx):
+    def training_step(self, batch, batch_idx):
         # emulate gans training
         opt_gen, opt_dis = self.optimizers()
 
@@ -1147,4 +1147,4 @@ def dis_closure():
 
 @RunIf(min_gpus=2, special=True)
 def test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model(tmpdir):
-    train_manual_optimization(tmpdir, "ddp", model_cls=TesManualOptimizationDDPModelToggleModel)
+    train_manual_optimization(tmpdir, "ddp", model_cls=TestManualOptimizationDDPModelToggleModel)
diff --git a/tests/trainer/optimization/test_multiple_optimizers.py b/tests/trainer/optimization/test_multiple_optimizers.py
index 84fdeab2c1311..5f0ca34015df0 100644
--- a/tests/trainer/optimization/test_multiple_optimizers.py
+++ b/tests/trainer/optimization/test_multiple_optimizers.py
@@ -14,6 +14,7 @@
 """
 Tests to ensure that the behaviours related to multiple optimizers works
 """
+import pytest
 import torch
 
 import pytorch_lightning as pl
@@ -90,11 +91,6 @@ def training_epoch_end(self, outputs) -> None:
             # outputs should be an array with an entry per optimizer
             assert len(outputs) == 2
 
-        def configure_optimizers(self):
-            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
-            optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
-            return optimizer, optimizer_2
-
     model = TestModel()
     model.val_dataloader = None
@@ -119,7 +115,7 @@ def __init__(self):
             super().__init__()
             self.automatic_optimization = False
 
-        def training_step(self, batch, batch_idx, optimizer_idx):
+        def training_step(self, batch, batch_idx):
             self.training_step_called = True
 
             # manual optimization
@@ -154,3 +150,20 @@ def training_epoch_end(self, outputs) -> None:
     trainer.fit(model)
 
     assert model.training_step_called
+
+
+def test_multiple_optimizers_no_opt_idx_argument(tmpdir):
+    """
+    Test that a ValueError is raised when multiple optimizers are used with
+    automatic optimization but `training_step` does not take an `optimizer_idx` argument.
+    """
+
+    class TestModel(MultiOptModel):
+
+        def training_step(self, batch, batch_idx):
+            return super().training_step(batch, batch_idx)
+
+    trainer = pl.Trainer(default_root_dir=tmpdir, fast_dev_run=2)
+
+    with pytest.raises(ValueError, match='`training_step` is missing the `optimizer_idx`'):
+        trainer.fit(TestModel())
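
Note: taken together, these changes mean that with `self.automatic_optimization = False` a user-defined `training_step` no longer receives `optimizer_idx`, while automatic optimization with multiple optimizers still requires it (enforced by the `ValueError` branch in `build_train_args`). Below is a minimal, hypothetical sketch of a module written against the new signature; the class, layer, and loss names are illustrative only, and it uses only the Lightning calls already shown in this diff (`self.optimizers()`, `self.manual_backward(loss, opt)`) plus the standard optimizer `step()`/`zero_grad()` methods.

    # Hypothetical two-optimizer module using the new manual-optimization signature.
    import torch
    from torch import nn, optim
    import pytorch_lightning as pl


    class TwoOptimizerModel(pl.LightningModule):

        def __init__(self):
            super().__init__()
            # opt out of automatic optimization, as in the docs examples above
            self.automatic_optimization = False
            self.gen = nn.Linear(32, 32)
            self.dis = nn.Linear(32, 32)

        def training_step(self, batch, batch_idx):
            # new signature: no `optimizer_idx`; fetch both optimizers directly
            opt_g, opt_d = self.optimizers()

            x = batch[0]  # assumes a list-style batch, like the test models above

            # toy losses, purely for illustration
            loss_g = self.gen(x).sum()
            self.manual_backward(loss_g, opt_g)
            opt_g.step()
            opt_g.zero_grad()

            loss_d = self.dis(x).sum()
            self.manual_backward(loss_d, opt_d)
            opt_d.step()
            opt_d.zero_grad()

        def configure_optimizers(self):
            return [
                optim.SGD(self.gen.parameters(), lr=1e-2),
                optim.SGD(self.dis.parameters(), lr=1e-2),
            ]

A module that keeps the old `training_step(self, batch, batch_idx, optimizer_idx)` signature in manual optimization continues to run, but triggers the `DeprecationWarning` added in `build_train_args` until support is dropped in v1.5.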