LightningCLI support for optimizers and schedulers via dependency injection #15869

Merged · 6 commits · Dec 12, 2022
Changes from 1 commit
96 changes: 63 additions & 33 deletions docs/source-pytorch/cli/lightning_cli_advanced_3.rst
@@ -217,8 +217,8 @@ If the CLI is implemented as ``LightningCLI(MyMainModel)`` the configuration wou
It is also possible to combine ``subclass_mode_model=True`` and submodules, thereby having two levels of ``class_path``.


Optimizers
^^^^^^^^^^
Fixed optimizer and scheduler
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In some cases, fixing the optimizer and/or learning rate scheduler might be desired instead of allowing a choice among
multiple classes. For this, you can manually add the arguments for specific classes by subclassing the CLI. The
following code snippet shows how to
@@ -251,58 +251,88 @@ where the arguments can be passed directly through the command line without spec

$ python trainer.py fit --optimizer.lr=0.01 --lr_scheduler.gamma=0.2
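
The subclass snippet that the preceding text refers to is collapsed here. A minimal sketch of such a subclass,
assuming ``Adam`` and ``ExponentialLR`` as the fixed classes to match the flags above (not necessarily the exact
content of the collapsed snippet), could be:

.. code-block:: python

    import torch
    from pytorch_lightning.cli import LightningCLI


    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            # Fix the optimizer and scheduler classes; only their init arguments
            # (e.g. ``lr`` and ``gamma``) remain configurable from the command line
            parser.add_optimizer_args(torch.optim.Adam)
            parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)

With this, ``--optimizer.lr=0.01`` and ``--lr_scheduler.gamma=0.2`` in the command above map directly onto the init
arguments of the fixed classes.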

The automatic implementation of ``configure_optimizers`` can be disabled by linking the configuration group, for
example when adding support for multiple optimizers:

Multiple optimizers and schedulers
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

By default, the CLIs support any optimizer and/or learning rate scheduler, automatically implementing
``configure_optimizers`` for a single optimizer and optionally a single scheduler. This behavior can be disabled by
providing ``auto_configure_optimizers=False`` on instantiation of :class:`~pytorch_lightning.cli.LightningCLI`. This
would be required, for example, to support multiple optimizers, selecting a particular optimizer class for each.
Similar to multiple submodules, this can be done via `dependency injection <https://en.wikipedia.org/wiki/Dependency_injection>`__.
Unlike with submodules, it is not possible to expect an instance of a class, because optimizers require the module's
parameters to optimize, which are only available after the module has been instantiated. Learning rate schedulers are
in a similar situation, requiring an optimizer instance. For these cases, dependency injection involves providing a
function that instantiates the respective class when called.

An example of a model that uses two optimizers is the following:

.. code-block:: python

from pytorch_lightning.cli import instantiate_class
    from typing import Callable, Iterable
from torch.optim import Optimizer


OptimizerCallable = Callable[[Iterable], Optimizer]


class MyModel(LightningModule):
def __init__(self, optimizer1_init: dict, optimizer2_init: dict):
def __init__(self, optimizer1: OptimizerCallable, optimizer2: OptimizerCallable):
super().__init__()
self.optimizer1_init = optimizer1_init
self.optimizer2_init = optimizer2_init
self.optimizer1 = optimizer1
self.optimizer2 = optimizer2

def configure_optimizers(self):
optimizer1 = instantiate_class(self.parameters(), self.optimizer1_init)
optimizer2 = instantiate_class(self.parameters(), self.optimizer2_init)
optimizer1 = self.optimizer1(self.parameters())
optimizer2 = self.optimizer2(self.parameters())
return [optimizer1, optimizer2]


class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.add_optimizer_args(nested_key="optimizer1", link_to="model.optimizer1_init")
parser.add_optimizer_args(nested_key="optimizer2", link_to="model.optimizer2_init")
cli = MyLightningCLI(MyModel, auto_configure_optimizers=False)

Note the type ``Callable[[Iterable], Optimizer]``, which denotes a function that receives a single argument, some
learnable parameters, and returns an optimizer instance. With this, from the command line it is possible to select the
class and init arguments for each of the optimizers, as follows:

cli = MyLightningCLI(MyModel)
.. code-block:: bash

The value given to ``optimizer*_init`` will always be a dictionary including ``class_path`` and ``init_args`` entries.
The function :func:`~pytorch_lightning.cli.instantiate_class` takes care of importing the class defined in
``class_path`` and instantiating it using some positional arguments, in this case ``self.parameters()``, and the
``init_args``. Any number of optimizers and learning rate schedulers can be added when using ``link_to``.
$ python trainer.py fit \
--model.optimizer1=Adam \
--model.optimizer1.lr=0.01 \
--model.optimizer2=AdamW \
--model.optimizer2.lr=0.0001
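
Conceptually, the classes and init arguments selected above reach the model as callables that behave like partially
applied constructors. A minimal sketch of the equivalent in plain Python, using ``functools.partial`` as a stand-in
for what the parser builds internally:

.. code-block:: python

    from functools import partial

    import torch

    # Roughly what the command line above amounts to: each callable takes the
    # parameters to optimize and returns a configured optimizer instance
    model = MyModel(
        optimizer1=partial(torch.optim.Adam, lr=0.01),
        optimizer2=partial(torch.optim.AdamW, lr=0.0001),
    )

The object the parser actually passes is not literally a ``functools.partial``, but it can be called the same way
inside ``configure_optimizers``.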

With shorthand notation:
In the example above, the ``OptimizerCallable`` type alias was created to illustrate what the type hint means. For
convenience, this type alias and one for learning rate schedulers are available in the ``cli`` module. An example of a
model that uses dependency injection for an optimizer and a learning rate scheduler is:

.. code-block:: bash
.. code-block:: python

$ python trainer.py fit \
--optimizer1=Adam \
--optimizer1.lr=0.01 \
--optimizer2=AdamW \
--optimizer2.lr=0.0001
    import torch
    from pytorch_lightning.cli import OptimizerCallable, LRSchedulerCallable, LightningCLI

You can also pass the class path directly, for example, if the optimizer hasn't been imported:

.. code-block:: bash
class MyModel(LightningModule):
def __init__(
self,
optimizer: OptimizerCallable = torch.optim.Adam,
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
):
super().__init__()
self.optimizer = optimizer
self.scheduler = scheduler

$ python trainer.py fit \
--optimizer1=torch.optim.Adam \
--optimizer1.lr=0.01 \
--optimizer2=torch.optim.AdamW \
--optimizer2.lr=0.0001
def configure_optimizers(self):
optimizer = self.optimizer(self.parameters())
            scheduler = self.scheduler(optimizer)
return {"optimizer": optimizer, "lr_scheduler": scheduler}


cli = MyLightningCLI(MyModel, auto_configure_optimizers=False)

Note that in this example, classes are used as defaults. This is compatible with the type hints, since a class is
itself a callable that receives the same first argument and returns an instance of the class. Classes that have more
than one required argument will not work as defaults. For these cases a lambda function can be used, e.g. ``optimizer:
OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01)``.
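
For instance, ``torch.optim.lr_scheduler.StepLR`` requires a ``step_size`` in addition to the optimizer, so the class
alone cannot be the default. A minimal sketch with lambda defaults for both the optimizer and the scheduler, reusing
the imports from the previous snippet (the specific classes and values are illustrative, not part of the example
above):

.. code-block:: python

    class MyModel(LightningModule):
        def __init__(
            self,
            # SGD requires ``lr`` and StepLR requires ``step_size``, so the plain
            # classes cannot serve as defaults; the lambdas supply those values
            optimizer: OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
            scheduler: LRSchedulerCallable = lambda o: torch.optim.lr_scheduler.StepLR(o, step_size=10),
        ):
            super().__init__()
            self.optimizer = optimizer
            self.scheduler = scheduler

        def configure_optimizers(self):
            optimizer = self.optimizer(self.parameters())
            scheduler = self.scheduler(optimizer)
            return {"optimizer": optimizer, "lr_scheduler": scheduler}

The defaults can still be overridden from the command line, for example ``--model.scheduler=ExponentialLR
--model.scheduler.gamma=0.3``.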


Run from Python
2 changes: 1 addition & 1 deletion requirements/pytorch/extra.txt
@@ -5,5 +5,5 @@
matplotlib>3.1, <3.6.2
omegaconf>=2.0.5, <2.3.0
hydra-core>=1.0.5, <1.3.0
jsonargparse[signatures]>=4.17.0, <4.18.0
jsonargparse[signatures]>=4.18.0, <4.19.0
rich>=10.14.0, !=10.15.0.a, <13.0.0
4 changes: 4 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -29,6 +29,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a warning when `self.log(..., logger=True)` is called without a configured logger ([#15814](https://github.com/Lightning-AI/lightning/pull/15814))


- Added `LightningCLI` support for optimizers and learning rate schedulers via callable type dependency injection ([#15869](https://github.com/Lightning-AI/lightning/pull/15869))


### Changed

- Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347))
25 changes: 17 additions & 8 deletions src/pytorch_lightning/cli.py
@@ -15,7 +15,7 @@
import sys
from functools import partial, update_wrapper
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

import torch
from lightning_utilities.core.imports import RequirementCache
@@ -49,9 +49,6 @@
locals()["Namespace"] = object


ArgsType = Optional[Union[List[str], Dict[str, Any], Namespace]]


class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
super().__init__(optimizer, *args, **kwargs)
@@ -64,6 +61,12 @@ def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any
LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[ReduceLROnPlateau]]


# Type aliases intended for convenience of CLI developers
ArgsType = Optional[Union[List[str], Dict[str, Any], Namespace]]
OptimizerCallable = Callable[[Iterable], Optimizer]
LRSchedulerCallable = Callable[[Optimizer], Union[torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau]]


class LightningArgumentParser(ArgumentParser):
"""Extension of jsonargparse's ArgumentParser for pytorch-lightning."""

@@ -274,6 +277,7 @@ def __init__(
subclass_mode_data: bool = False,
args: ArgsType = None,
run: bool = True,
auto_configure_optimizers: bool = True,
auto_registry: bool = False,
**kwargs: Any, # Remove with deprecations of v1.10
) -> None:
@@ -326,6 +330,7 @@ def __init__(
self.trainer_defaults = trainer_defaults or {}
self.seed_everything_default = seed_everything_default
self.parser_kwargs = parser_kwargs or {} # type: ignore[var-annotated] # github.com/python/mypy/issues/6463
self.auto_configure_optimizers = auto_configure_optimizers

self._handle_deprecated_params(kwargs)

@@ -447,10 +452,11 @@ def _add_arguments(self, parser: LightningArgumentParser) -> None:
self.add_core_arguments_to_parser(parser)
self.add_arguments_to_parser(parser)
# add default optimizer args if necessary
if not parser._optimizers: # already added by the user in `add_arguments_to_parser`
parser.add_optimizer_args((Optimizer,))
if not parser._lr_schedulers: # already added by the user in `add_arguments_to_parser`
parser.add_lr_scheduler_args(LRSchedulerTypeTuple)
if self.auto_configure_optimizers:
if not parser._optimizers: # already added by the user in `add_arguments_to_parser`
parser.add_optimizer_args((Optimizer,))
if not parser._lr_schedulers: # already added by the user in `add_arguments_to_parser`
parser.add_lr_scheduler_args(LRSchedulerTypeTuple)
self.link_optimizers_and_lr_schedulers(parser)

def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
@@ -602,6 +608,9 @@ def configure_optimizers(
def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
"""Overrides the model's :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers` method
if a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC'."""
if not self.auto_configure_optimizers:
return

parser = self._parser(subcommand)

def get_automatic(
52 changes: 52 additions & 0 deletions tests/tests_pytorch/test_cli.py
@@ -36,7 +36,9 @@
instantiate_class,
LightningArgumentParser,
LightningCLI,
LRSchedulerCallable,
LRSchedulerTypeTuple,
OptimizerCallable,
SaveConfigCallback,
)
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
@@ -706,6 +708,56 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)


def test_lightning_cli_optimizers_and_lr_scheduler_with_callable_type():
class TestModel(BoringModel):
def __init__(
self,
optim1: OptimizerCallable = torch.optim.Adam,
optim2: OptimizerCallable = torch.optim.Adagrad,
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
):
super().__init__()
self.optim1 = optim1
self.optim2 = optim2
self.scheduler = scheduler

def configure_optimizers(self):
optim1 = self.optim1(self.parameters())
optim2 = self.optim2(self.parameters())
scheduler = self.scheduler(optim2)
return (
{"optimizer": optim1},
{"optimizer": optim2, "lr_scheduler": scheduler},
)

out = StringIO()
with mock.patch("sys.argv", ["any.py", "-h"]), redirect_stdout(out), pytest.raises(SystemExit):
LightningCLI(TestModel, run=False, auto_configure_optimizers=False)
out = out.getvalue()
assert "--optimizer" not in out
assert "--lr_scheduler" not in out
assert "--model.optim1" in out
assert "--model.optim2" in out
assert "--model.scheduler" in out

cli_args = [
"--model.optim1=Adagrad",
"--model.optim2=SGD",
"--model.optim2.lr=0.007",
"--model.scheduler=ExponentialLR",
"--model.scheduler.gamma=0.3",
]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = LightningCLI(TestModel, run=False, auto_configure_optimizers=False)

init = cli.model.configure_optimizers()
assert isinstance(init[0]["optimizer"], torch.optim.Adagrad)
assert isinstance(init[1]["optimizer"], torch.optim.SGD)
assert isinstance(init[1]["lr_scheduler"], torch.optim.lr_scheduler.ExponentialLR)
assert init[1]["optimizer"].param_groups[0]["lr"] == 0.007
assert init[1]["lr_scheduler"].gamma == 0.3


@pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn])
def test_lightning_cli_trainer_fn(fn):
class TestCLI(LightningCLI):
Expand Down