
remove optimizer_idx arg in manual optimization #6093

Merged · 7 commits · Mar 7, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -75,6 +75,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207))


- Removed `optimizer_idx` argument from `training_step` in manual optimization ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))


### Fixed

- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
4 changes: 2 additions & 2 deletions README.md
@@ -318,9 +318,9 @@ class LitAutoEncoder(pl.LightningModule):
super().__init__()
self.automatic_optimization = False

def training_step(self, batch, batch_idx, optimizer_idx):
def training_step(self, batch, batch_idx):
# access your optimizers with use_pl_optimizer=False. Default is True
(opt_a, opt_b) = self.optimizers(use_pl_optimizer=True)
opt_a, opt_b = self.optimizers(use_pl_optimizer=True)

loss_a = ...
self.manual_backward(loss_a, opt_a)
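For reference, a minimal self-contained sketch of how the updated README snippet fits together after this change. The encoder/decoder layers, losses, and learning rates below are hypothetical placeholders, not the README's exact code; `manual_backward` is called with the optimizer argument, matching the snippet above:

```python
import torch
import pytorch_lightning as pl


class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # hypothetical submodules standing in for the README's encoder/decoder
        self.encoder = torch.nn.Linear(28 * 28, 3)
        self.decoder = torch.nn.Linear(3, 28 * 28)
        # opt out of automatic optimization
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        # note: no `optimizer_idx` argument in manual optimization anymore
        opt_a, opt_b = self.optimizers(use_pl_optimizer=True)

        x = batch[0].view(batch[0].size(0), -1)

        # first update, driven by opt_a
        loss_a = torch.nn.functional.mse_loss(self.decoder(self.encoder(x)), x)
        opt_a.zero_grad()
        self.manual_backward(loss_a, opt_a)
        opt_a.step()

        # second update, driven by opt_b
        loss_b = torch.nn.functional.mse_loss(self.decoder(self.encoder(x)), x)
        opt_b.zero_grad()
        self.manual_backward(loss_b, opt_b)
        opt_b.step()

    def configure_optimizers(self):
        return (
            torch.optim.Adam(self.encoder.parameters(), lr=1e-3),
            torch.optim.Adam(self.decoder.parameters(), lr=1e-3),
        )
```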
4 changes: 1 addition & 3 deletions docs/source/common/lightning_module.rst
@@ -952,14 +952,12 @@ When set to ``False``, Lightning does not automate the optimization process. Thi

This is recommended only if using 2+ optimizers AND if you know how to perform the optimization procedure properly. Note that automatic optimization can still be used with multiple optimizers by relying on the ``optimizer_idx`` parameter. Manual optimization is most useful for research topics like reinforcement learning, sparse coding, and GAN research.

In the multi-optimizer case, ignore the ``optimizer_idx`` argument and use the optimizers directly

.. code-block:: python

def __init__(self):
self.automatic_optimization = False

def training_step(self, batch, batch_idx, optimizer_idx):
def training_step(self, batch, batch_idx):
# access your optimizers with use_pl_optimizer=False. Default is True
opt_a, opt_b = self.optimizers(use_pl_optimizer=True)

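The paragraph above notes that automatic optimization still relies on `optimizer_idx` when several optimizers are configured. A hedged sketch of that contrasting case, where the argument stays in the signature; module internals and losses are placeholders:

```python
import torch
import pytorch_lightning as pl


class AutoOptTwoOptimizers(pl.LightningModule):
    """Hypothetical two-optimizer module that keeps automatic optimization on."""

    def __init__(self):
        super().__init__()
        self.generator = torch.nn.Linear(32, 32)
        self.discriminator = torch.nn.Linear(32, 1)
        # automatic_optimization defaults to True, so `optimizer_idx` is still passed

    def training_step(self, batch, batch_idx, optimizer_idx):
        x = batch[0]
        if optimizer_idx == 0:
            # generator update: Lightning runs backward/step for us
            return self.discriminator(self.generator(x)).mean()
        if optimizer_idx == 1:
            # discriminator update (placeholder loss)
            return -self.discriminator(x).mean()

    def configure_optimizers(self):
        return (
            torch.optim.SGD(self.generator.parameters(), lr=0.01),
            torch.optim.SGD(self.discriminator.parameters(), lr=0.01),
        )
```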
14 changes: 11 additions & 3 deletions docs/source/common/optimizers.rst
@@ -51,7 +51,10 @@ to manually manage the optimization process. To do so, do the following:

.. code-block:: python

def training_step(batch, batch_idx):
def __init__(self):
self.automatic_optimization = False

def training_step(self, batch, batch_idx):
opt = self.optimizers()

loss = self.compute_loss(batch)
Expand All @@ -69,7 +72,10 @@ Here is the same example as above using a ``closure``.

.. testcode:: python

def training_step(batch, batch_idx):
def __init__(self):
self.automatic_optimization = False

def training_step(self, batch, batch_idx):
opt = self.optimizers()

def forward_and_backward():
@@ -126,7 +132,6 @@ Here is the same example as above using a ``closure``.
# Optimize Discriminator #
###########################
d_opt.zero_grad()

d_x = self.D(X)
errD_real = self.criterion(d_x, real_label)

@@ -179,6 +184,9 @@ Here is an example for advanced use-case.

...

def __init__(self):
self.automatic_optimization = False

def training_step(self, batch, batch_idx):
# Implementation follows https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
g_opt, d_opt = self.optimizers()
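The closure example referenced in this file is truncated in the view above. A minimal sketch of the pattern under the same assumptions — `compute_loss` mirrors the helper used in the docs snippet, while the layer and optimizer choices are placeholders:

```python
import torch
import pytorch_lightning as pl


class ClosureExample(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.automatic_optimization = False

    def compute_loss(self, batch):
        # hypothetical helper mirroring `self.compute_loss(batch)` in the docs snippet
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        def forward_and_backward():
            # the optimizer may call this closure several times (e.g. LBFGS),
            # so grads are cleared and the loss recomputed on every call
            opt.zero_grad()
            loss = self.compute_loss(batch)
            self.manual_backward(loss, opt)
            return loss

        # `step` runs the closure itself before applying the update
        opt.step(closure=forward_and_backward)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```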
2 changes: 1 addition & 1 deletion docs/source/starter/new-project.rst
@@ -265,7 +265,7 @@ Turn off automatic optimization and you control the train loop!
def __init__(self):
self.automatic_optimization = False

def training_step(self, batch, batch_idx, optimizer_idx):
def training_step(self, batch, batch_idx):
# access your optimizers with use_pl_optimizer=False. Default is True
opt_a, opt_b = self.optimizers(use_pl_optimizer=True)

15 changes: 10 additions & 5 deletions pytorch_lightning/trainer/training_loop.py
@@ -310,7 +310,7 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
closure_loss = None
untouched_loss = None

if self.trainer.train_loop.automatic_optimization:
if self.automatic_optimization:
# accumulate loss
# (if accumulate_grad_batches = 1 no effect)
if is_result_obj:
@@ -840,12 +840,17 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens):

if len(self.trainer.optimizers) > 1:
if self.trainer.has_arg("training_step", "optimizer_idx"):
if not self.automatic_optimization:
self.warning_cache.warn(
"`training_step` hook signature has changed in v1.3."
" `optimizer_idx` argument has been removed in case of manual optimization. Support for"
" the old signature will be removed in v1.5", DeprecationWarning
)
args.append(opt_idx)
else:
num_opts = len(self.trainer.optimizers)
elif not self.trainer.has_arg("training_step", "optimizer_idx") and self.automatic_optimization:
raise ValueError(
f"Your LightningModule defines {num_opts} optimizers but "
f'training_step is missing the "optimizer_idx" argument.'
f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but"
' `training_step` is missing the `optimizer_idx` argument.'
)

# pass hiddens if using tbptt
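To make the control flow of the hunk above easier to follow, here is a standalone sketch of how the arguments for `training_step` are assembled once this change lands. It is not the Trainer's actual method; the function name and parameters are hypothetical stand-ins for trainer state:

```python
def build_train_args_sketch(batch, batch_idx, opt_idx, *, num_optimizers,
                            has_opt_idx_arg, automatic_optimization, warn):
    """Sketch of the decision implemented in the hunk above."""
    args = [batch, batch_idx]
    if num_optimizers > 1:
        if has_opt_idx_arg:
            if not automatic_optimization:
                # manual optimization + old signature: deprecation warning,
                # but the argument is still passed until v1.5
                warn(
                    "`training_step` hook signature has changed in v1.3."
                    " `optimizer_idx` argument has been removed in case of manual optimization."
                    " Support for the old signature will be removed in v1.5"
                )
            args.append(opt_idx)
        elif automatic_optimization:
            # automatic optimization with several optimizers still requires the argument
            raise ValueError(
                f"Your LightningModule defines {num_optimizers} optimizers but"
                " `training_step` is missing the `optimizer_idx` argument."
            )
    return args


# e.g. a two-optimizer module keeping the old signature under manual optimization
# warns and still receives [batch, batch_idx, opt_idx]
```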
3 changes: 2 additions & 1 deletion tests/core/test_lightning_optimizer.py
@@ -89,8 +89,9 @@ def __init__(self):
super().__init__()
self.automatic_optimization = False

def training_step(self, batch, batch_idx, optimizer_idx=None):
def training_step(self, batch, batch_idx):
opt_1, opt_2 = self.optimizers()

assert isinstance(opt_1, LightningOptimizer)
assert isinstance(opt_2, LightningOptimizer)

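The assertions in this test rely on `self.optimizers()` returning `LightningOptimizer` wrappers by default. A small hedged sketch of the two access modes mentioned in the README comment (`use_pl_optimizer` defaults to `True`); the module below is hypothetical:

```python
import torch
import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer


class OptimizerAccessExample(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        # default (use_pl_optimizer=True): wrapped optimizers, as the test asserts
        opt_1, opt_2 = self.optimizers()
        assert isinstance(opt_1, LightningOptimizer)

        # use_pl_optimizer=False returns the underlying torch.optim optimizers
        raw_1, raw_2 = self.optimizers(use_pl_optimizer=False)
        assert isinstance(raw_1, torch.optim.Optimizer)

    def configure_optimizers(self):
        return (
            torch.optim.SGD(self.layer.parameters(), lr=0.1),
            torch.optim.SGD(self.layer.parameters(), lr=0.1),
        )
```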
23 changes: 23 additions & 0 deletions tests/deprecated_api/test_remove_1-5.py
@@ -16,6 +16,7 @@
from unittest import mock

import pytest
from torch import optim

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import WandbLogger
@@ -74,3 +75,25 @@ def test_v1_5_0_running_sanity_check():
trainer = Trainer()
with pytest.deprecated_call(match='has been renamed to `Trainer.sanity_checking`'):
assert not trainer.running_sanity_check


def test_old_training_step_signature_with_opt_idx_manual_opt(tmpdir):

class OldSignatureModel(BoringModel):

def __init__(self):
super().__init__()
self.automatic_optimization = False

def training_step(self, batch, batch_idx, optimizer_idx):
assert optimizer_idx is not None
return super().training_step(batch, batch_idx)

def configure_optimizers(self):
return [optim.SGD(self.parameters(), lr=1e-2), optim.SGD(self.parameters(), lr=1e-2)]

model = OldSignatureModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)

with pytest.deprecated_call(match="`training_step` .* `optimizer_idx` .* manual .* will be removed in v1.5"):
trainer.fit(model)
34 changes: 17 additions & 17 deletions tests/trainer/optimization/test_manual_optimization.py
@@ -40,9 +40,9 @@ def __init__(self):
super().__init__()
self.automatic_optimization = False

def training_step(self, batch, batch_idx, optimizer_idx):
def training_step(self, batch, batch_idx):
# manual
(opt_a, opt_b) = self.optimizers()
opt_a, opt_b = self.optimizers()
loss_1 = self.step(batch[0])

# make sure there are no grads
@@ -107,9 +107,9 @@ def __init__(self):
super().__init__()
self.automatic_optimization = False

def training_step(self, batch, batch_idx, optimizer_idx):
def training_step(self, batch, batch_idx):
# manual
(opt_a, opt_b) = self.optimizers()
opt_a, opt_b = self.optimizers()
loss_1 = self.step(batch[0])

# make sure there are no grads
@@ -176,9 +176,9 @@ def __init__(self):
super().__init__()
self.automatic_optimization = False

def training_step(self, batch, batch_idx, optimizer_idx):
def training_step(self, batch, batch_idx):
# manual
(opt_a, opt_b) = self.optimizers()
opt_a, opt_b = self.optimizers()
loss_1 = self.step(batch[0])

# make sure there are no grads
@@ -251,9 +251,9 @@ def __init__(self):
super().__init__()
self.automatic_optimization = False

def training_step(self, batch, batch_idx, optimizer_idx):
def training_step(self, batch, batch_idx):
# manual
(opt_a, opt_b) = self.optimizers()
opt_a, opt_b = self.optimizers()
loss_1 = self.step(batch[0])

# make sure there are no grads
@@ -321,9 +321,9 @@ def __init__(self):
super().__init__()
self.automatic_optimization = False

def training_step(self, batch, batch_idx, optimizer_idx):
def training_step(self, batch, batch_idx):
# manual
(opt_a, opt_b) = self.optimizers()
opt_a, opt_b = self.optimizers()
x = batch[0]

loss_1 = self(x)
@@ -610,9 +610,9 @@ def on_after_backward(self):
if not (torch.isinf(norm) or torch.isnan(norm)):
assert norm.item() < 100, norm.item()

def training_step(self, batch, batch_idx, optimizer_idx):
def training_step(self, batch, batch_idx):
# manual
(opt_a, opt_b) = self.optimizers()
opt_a, opt_b = self.optimizers()
x = batch[0]

loss_1 = self(x)
@@ -886,7 +886,7 @@ def __init__(self):
super().__init__()
self.automatic_optimization = False

def training_step(self, batch, batch_idx, optimizer_idx):
def training_step(self, batch, batch_idx):

# emulate gans training
opt_gen, opt_dis = self.optimizers()
@@ -981,7 +981,7 @@ def manual_sync_grad(self) -> bool:
torch_distrib.all_reduce(self.layer.weight.grad.data, async_op=False)
return True

def training_step(self, batch, batch_idx, optimizer_idx):
def training_step(self, batch, batch_idx):

# emulate gans training
opt_gen, opt_dis = self.optimizers()
@@ -1088,9 +1088,9 @@ def test_step_with_optimizer_closure_with_different_frequencies_ddp_spawn(tmpdir):
train_manual_optimization(tmpdir, "ddp_spawn")


class TesManualOptimizationDDPModelToggleModel(TesManualOptimizationDDPModel):
class TestManualOptimizationDDPModelToggleModel(TesManualOptimizationDDPModel):

def training_step(self, batch, batch_idx, optimizer_idx):
def training_step(self, batch, batch_idx):

# emulate gans training
opt_gen, opt_dis = self.optimizers()
@@ -1147,4 +1147,4 @@ def dis_closure():

@RunIf(min_gpus=2, special=True)
def test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model(tmpdir):
train_manual_optimization(tmpdir, "ddp", model_cls=TesManualOptimizationDDPModelToggleModel)
train_manual_optimization(tmpdir, "ddp", model_cls=TestManualOptimizationDDPModelToggleModel)
25 changes: 19 additions & 6 deletions tests/trainer/optimization/test_multiple_optimizers.py
@@ -14,6 +14,7 @@
"""
Tests to ensure that the behaviours related to multiple optimizers work
"""
import pytest
import torch

import pytorch_lightning as pl
@@ -90,11 +91,6 @@ def training_epoch_end(self, outputs) -> None:
# outputs should be an array with an entry per optimizer
assert len(outputs) == 2

def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return optimizer, optimizer_2

model = TestModel()
model.val_dataloader = None

@@ -119,7 +115,7 @@ def __init__(self):
super().__init__()
self.automatic_optimization = False

def training_step(self, batch, batch_idx, optimizer_idx):
def training_step(self, batch, batch_idx):
self.training_step_called = True

# manual optimization
@@ -154,3 +150,20 @@ def training_epoch_end(self, outputs) -> None:
trainer.fit(model)

assert model.training_step_called


def test_multiple_optimizers_no_opt_idx_argument(tmpdir):
"""
Test that an error is raised if no optimizer_idx is present when
multiple optimizeres are passed in case of automatic_optimization
"""

class TestModel(MultiOptModel):

def training_step(self, batch, batch_idx):
return super().training_step(batch, batch_idx)

trainer = pl.Trainer(default_root_dir=tmpdir, fast_dev_run=2)

with pytest.raises(ValueError, match='`training_step` is missing the `optimizer_idx`'):
trainer.fit(TestModel())