Simple reproducibility with minimum boilerplate train_cli #4492

Merged · 46 commits · Apr 6, 2021

Commits

b04b515
Added new trainer_cli that reduces boilerplate
mauvilsa Nov 3, 2020
186b6d8
- Converted trainer_cli function into a class that can be more easily…
mauvilsa Nov 18, 2020
c7cf9aa
Merge branch 'master' into config_files_argparse
mauvilsa Nov 18, 2020
d286f49
- Fixes required by pep8speaks in trainer_cli.py.
mauvilsa Nov 18, 2020
ed64f31
Renamed class to LightningCLI and other minor fixes
mauvilsa Nov 19, 2020
35d3aa3
- Fixed bug in testcode of trainer_cli.rst.
mauvilsa Nov 30, 2020
6a4ae27
- Renamed files to reflect new class name LightningCLI.
mauvilsa Dec 2, 2020
6f0f2a0
Work on LightningCLI:
mauvilsa Jan 8, 2021
69794a1
Merge branch 'release/1.2-dev' into config_files_argparse
mauvilsa Jan 8, 2021
e91a351
Work on LightningCLI:
mauvilsa Jan 8, 2021
7a76ea6
Work on LightningCLI:
mauvilsa Jan 8, 2021
03be5c4
Swap instantiation of datamodule and model in LightningCLI.
mauvilsa Jan 11, 2021
69904d7
Changed LightningArgumentParser add args methods to a single one add_…
mauvilsa Jan 12, 2021
a726c3f
Made pytorch_lightning.utilities.cli importable even when jsonargpars…
mauvilsa Jan 15, 2021
e498c33
Merge branch 'release/1.2-dev' into config_files_argparse
mauvilsa Jan 15, 2021
2cd09d5
- Fix "Check valid import formatting with isort".
mauvilsa Jan 15, 2021
04fdd5a
Fix "Check valid import formatting with isort"
mauvilsa Jan 15, 2021
a4a4a55
Fix "Check valid import formatting with isort"
mauvilsa Jan 15, 2021
4d2e796
Merge branch 'release/1.2-dev' into config_files_argparse
Borda Jan 19, 2021
1403501
Fix "Check valid import formatting with isort"
mauvilsa Jan 21, 2021
dcbd3d7
Fix "Check valid import formatting with isort"
mauvilsa Jan 21, 2021
2635775
Merge branch 'release/1.2-dev' into config_files_argparse
mauvilsa Feb 1, 2021
4d7f5b9
Update to reflect change in structure in docs/sources
mauvilsa Feb 1, 2021
e4f219a
Merge branch 'master' into config_files_argparse
carmocca Feb 15, 2021
a4da435
Update to latest changes
carmocca Feb 16, 2021
f4e0fbb
Better formatting
carmocca Feb 16, 2021
93f5e75
Merge branch 'master' into config_files_argparse
carmocca Feb 21, 2021
36a578b
Merge branch 'master' into config_files_argparse
Borda Mar 1, 2021
e682a80
Merge branch 'master' into config_files_argparse
carmocca Mar 9, 2021
9c4fc7f
- Change wrapping size to 120 in lightning_cli.rst.
mauvilsa Mar 9, 2021
ebc4944
- Fix missing space around operator.
mauvilsa Mar 9, 2021
aa0a129
- Added trainer_defaults callbacks to unit tests
mauvilsa Mar 9, 2021
c9fc74e
Merge branch 'master' into config_files_argparse
carmocca Mar 30, 2021
05e460c
Merge branch 'master' into config_files_argparse
Borda Mar 30, 2021
1750266
Added unit test for LightningCLI with callbacks as argument
mauvilsa Mar 31, 2021
08b3e5c
Added to LightningCLI documentation the need for the 'all' extras req…
mauvilsa Mar 31, 2021
4e93200
Fixed PEP8 issue in test_lightning_cli.py
mauvilsa Mar 31, 2021
bac68b5
Add beta warnings
carmocca Apr 6, 2021
ee0ff21
Rename test file
carmocca Apr 6, 2021
7475b9f
Refactor callback test. Fix missing log_dir
carmocca Apr 6, 2021
c5b0cec
Typing
carmocca Apr 6, 2021
ff43a8b
Move test file to utilities. Use dict constructor which plays nicer w…
carmocca Apr 6, 2021
c2f0b86
Refactor fn
carmocca Apr 6, 2021
35f5d0d
Use trainer.log_dir
carmocca Apr 6, 2021
9562061
Fix docs
carmocca Apr 6, 2021
5604c04
Minor docs change
carmocca Apr 6, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -9,12 +9,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `LightningCLI` class to provide simple reproducibility with minimum boilerplate training CLI. ([#4492](https://github.com/PyTorchLightning/pytorch-lightning/pull/4492))


- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))


- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))



### Changed


1 change: 1 addition & 0 deletions docs/source/api_references.rst
@@ -93,5 +93,6 @@ Utilities API
    :toctree: api
    :nosignatures:

    cli
    argparse_utils
    seed
341 changes: 341 additions & 0 deletions docs/source/common/lightning_cli.rst
@@ -0,0 +1,341 @@
.. testsetup:: *
    :skipif: not _JSONARGPARSE_AVAILABLE

    from unittest import mock
    from typing import List
    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.core.datamodule import LightningDataModule
    from pytorch_lightning.utilities.cli import LightningCLI

    original_fit = LightningCLI.fit
    LightningCLI.fit = lambda self: None

    class MyModel(LightningModule):
        def __init__(
            self,
            encoder_layers: int = 12,
            decoder_layers: List[int] = [2, 4]
        ):
            """Example encoder-decoder model

            Args:
                encoder_layers: Number of layers for the encoder
                decoder_layers: Number of layers for each decoder block
            """
            pass

    class MyDataModule(LightningDataModule):
        pass

    def send_email(address, message):
        pass

    MyModelBaseClass = MyModel
    MyDataModuleBaseClass = MyDataModule

    mock_argv = mock.patch("sys.argv", ["any.py"])
    mock_argv.start()

.. testcleanup:: *

    LightningCLI.fit = original_fit
    mock_argv.stop()


Lightning CLI and config files
------------------------------

Another source of boilerplate that Lightning can help to reduce is the
implementation of command line tools for training. Furthermore, it provides a
standardized way to configure training runs using a single file that includes
settings for :class:`~pytorch_lightning.trainer.trainer.Trainer` as well as the
user extended :class:`~pytorch_lightning.core.lightning.LightningModule` and
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes. The
full configuration is automatically saved in the log directory, which greatly
simplifies the reproducibility of experiments.

The main requirement for user extended classes to be made configurable is that
all relevant init arguments have type hints. This is not very demanding since
it is good practice anyway. As a bonus, if the arguments are described in the
docstrings, then the help of the training tool will display them (see the
:class:`MyModel` example below).

----------


LightningCLI
^^^^^^^^^^^^

In the case in which the user's
:class:`~pytorch_lightning.core.lightning.LightningModule` class implements all
required :code:`*_dataloader` methods, a :code:`trainer.py` tool can be as
simple as:

.. testcode::

    from pytorch_lightning.utilities.cli import LightningCLI

    cli = LightningCLI(MyModel)

The help of the tool, describing all configurable options and their default
values, can be shown by running :code:`python trainer.py --help`. Default
options can be changed by providing individual command line arguments. However,
it is better practice to create a configuration file and provide this to the
tool. A way to do this would be:

.. code-block:: bash

    # Dump default configuration to have as reference
    python trainer.py --print_config > default_config.yaml
    # Create config including only options to modify
    nano config.yaml
    # Run training using created configuration
    python trainer.py --config config.yaml

The instantiation of the :class:`~pytorch_lightning.utilities.cli.LightningCLI`
class takes care of parsing command line and config file options, instantiating
the classes, setting up a callback to save the config in the log directory, and
finally running :func:`trainer.fit`. The resulting :code:`cli` object can be
used, for instance, to get the result of fit, i.e., :code:`cli.fit_result`.
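
For example, the resulting object can be inspected after the run. A minimal
sketch (:code:`cli.fit_result` is described above; the :code:`cli.model`
attribute is an assumption made for illustration):

.. code-block:: python

    from pytorch_lightning.utilities.cli import LightningCLI

    cli = LightningCLI(MyModel)      # parses options, instantiates classes, runs trainer.fit
    print(cli.fit_result)            # value returned by trainer.fit
    print(type(cli.model).__name__)  # instantiated model, assumed to be kept on the cli object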

After multiple training runs with different configurations, each run will have
a :code:`config.yaml` file in its respective log directory. This file can be
used as a reference to know in detail all the settings that were used in a
particular run, and also to trivially reproduce a training run, e.g.:

.. code-block:: bash

    python trainer.py --config lightning_logs/version_7/config.yaml

If a separate :class:`~pytorch_lightning.core.datamodule.LightningDataModule`
class is required, the trainer tool just needs a small modification as follows:

.. testcode::

    from pytorch_lightning.utilities.cli import LightningCLI

    cli = LightningCLI(MyModel, MyDataModule)

The start of a possible implementation of :class:`MyModel`, including the
recommended argument descriptions in the docstring, could be the one below.
Note that thanks to the type hints and docstrings there is no need to duplicate
this information to define the configurable arguments.

.. code-block:: python

    class MyModel(LightningModule):

        def __init__(
            self,
            encoder_layers: int = 12,
            decoder_layers: List[int] = [2, 4]
        ):
            """Example encoder-decoder model

            Args:
                encoder_layers: Number of layers for the encoder
                decoder_layers: Number of layers for each decoder block
            """
            ...

With this model class, the help of the trainer tool would look as follows:

.. code-block:: bash

    $ python trainer.py --help
    usage: trainer.py [-h] [--print_config] [--config CONFIG]
                      [--trainer.logger LOGGER]
                      ...

    pytorch-lightning trainer command line tool

    optional arguments:
      -h, --help            show this help message and exit
      --print_config        print configuration and exit
      --config CONFIG       Path to a configuration file in json or yaml format.
                            (default: null)

    Customize every aspect of training via flags:
      ...
      --trainer.max_epochs MAX_EPOCHS
                            Stop training once this number of epochs is reached.
                            (type: int, default: 1000)
      --trainer.min_epochs MIN_EPOCHS
                            Force training for at least these many epochs (type: int,
                            default: 1)
      ...

    Example encoder-decoder model:
      --model.encoder_layers ENCODER_LAYERS
                            Number of layers for the encoder (type: int, default: 12)
      --model.decoder_layers DECODER_LAYERS
                            Number of layers for each decoder block (type: List[int],
                            default: [2, 4])

The default configuration that option :code:`--print_config` gives is in yaml
format and for the example above would look as follows:

.. code-block:: bash

    $ python trainer.py --print_config
    model:
      decoder_layers:
      - 2
      - 4
      encoder_layers: 12
    trainer:
      accelerator: null
      accumulate_grad_batches: 1
      amp_backend: native
      amp_level: O2
      ...

Note that there is a section for each class (model and trainer) including all
the init parameters of the class. This grouping is also used in the formatting
of the help shown previously.
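
As a sketch of how this grouping looks programmatically, the dumped
configuration can be loaded back into nested dictionaries (assuming PyYAML is
installed and using the :code:`default_config.yaml` created earlier):

.. code-block:: python

    import yaml

    with open('default_config.yaml') as f:
        config = yaml.safe_load(f)

    print(sorted(config.keys()))              # e.g. ['model', 'trainer']
    print(config['model']['encoder_layers'])  # 12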


Trainer Callbacks and arguments with class type
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

A very important argument of the
:class:`~pytorch_lightning.trainer.trainer.Trainer` class is
:code:`callbacks`. In contrast to simpler arguments that just take numbers or
strings, :code:`callbacks` expects a list of instances of subclasses of
:class:`~pytorch_lightning.callbacks.Callback`. To specify this kind of
argument in a config file, each callback must be given as a dictionary
including a :code:`class_path` entry with the import path of the class, and
optionally an :code:`init_args` entry with the arguments required to
instantiate it. Therefore, a simple configuration file that defines a couple of
callbacks is the following:

.. code-block:: yaml

    trainer:
      callbacks:
        - class_path: pytorch_lightning.callbacks.EarlyStopping
          init_args:
            patience: 5
        - class_path: pytorch_lightning.callbacks.LearningRateMonitor
          init_args:
            ...
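
For reference, this configuration corresponds roughly to instantiating the
callbacks in code as in the following sketch (the :code:`init_args` of
:code:`LearningRateMonitor` are elided above, so its defaults are assumed
here):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor

    trainer = Trainer(callbacks=[
        EarlyStopping(patience=5),
        LearningRateMonitor(),
    ])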

Similar to the callbacks, any arguments in
:class:`~pytorch_lightning.trainer.trainer.Trainer` and in user extended
:class:`~pytorch_lightning.core.lightning.LightningModule` and
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes whose
type hint is a class can be configured the same way, using :code:`class_path`
and :code:`init_args`.
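
As a hypothetical illustration (the :code:`backbone` parameter and its type are
invented for this example), a model with a class-typed init parameter could be
declared as:

.. code-block:: python

    import torch
    from pytorch_lightning.core.lightning import LightningModule

    class MyModelWithBackbone(LightningModule):
        def __init__(self, backbone: torch.nn.Module):
            super().__init__()
            # the parser builds `backbone` from a class_path/init_args entry in the config
            self.backbone = backbone

In a config file, :code:`model.backbone` would then be given as a dictionary
with :code:`class_path` and :code:`init_args` entries, just like the callbacks
above.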


Multiple models and/or datasets
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the previous examples :class:`~pytorch_lightning.utilities.cli.LightningCLI`
works only for a single model and datamodule class. However, there are many
cases in which the objective is to be able to easily run experiments with
multiple models and datasets. For these cases the tool can be configured such
that the model and/or the datamodule are specified by an import path and init
arguments. For example, with a tool implemented as:

.. testcode::

    from pytorch_lightning.utilities.cli import LightningCLI

    cli = LightningCLI(
        MyModelBaseClass,
        MyDataModuleBaseClass,
        subclass_mode_model=True,
        subclass_mode_data=True
    )

A possible config file could be as follows:

.. code-block:: yaml

    model:
      class_path: mycode.mymodels.MyModel
      init_args:
        decoder_layers:
        - 2
        - 4
        encoder_layers: 12
    data:
      class_path: mycode.mydatamodules.MyDataModule
      init_args:
        ...
    trainer:
      callbacks:
        - class_path: pytorch_lightning.callbacks.EarlyStopping
          init_args:
            patience: 5
      ...

Only model classes that are a subclass of :code:`MyModelBaseClass` would be
allowed, and similarly only subclasses of :code:`MyDataModuleBaseClass`.


Customizing LightningCLI
^^^^^^^^^^^^^^^^^^^^^^^^

The init parameters of the
:class:`~pytorch_lightning.utilities.cli.LightningCLI` class can be used to
customize some things, namely: the description of the tool, enabling the
parsing of environment variables, and additional arguments to instantiate the
trainer and the configuration parser.

Nevertheless, the init arguments are not enough for many use cases. For this
reason the class is designed so that it can be extended to customize different
parts of the command line tool. The argument parser class used by
:class:`~pytorch_lightning.utilities.cli.LightningCLI` is
:class:`~pytorch_lightning.utilities.cli.LightningArgumentParser`, which is an
extension of Python's argparse, thus adding arguments can be done using the
:func:`add_argument` method. In contrast to argparse it has additional methods
to add arguments, for example :func:`add_class_arguments` adds all the
arguments from the init of a class, though requiring the parameters to have
type hints. For more details about this please refer to the `respective
documentation <https://omni-us.github.io/jsonargparse/#classes-methods-and-functions>`_.
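
A brief sketch of these parser methods (the :code:`Loss` class is hypothetical
and only serves to illustrate :func:`add_class_arguments`):

.. code-block:: python

    from pytorch_lightning.utilities.cli import LightningArgumentParser

    class Loss:
        def __init__(self, margin: float = 1.0):
            """Hypothetical loss settings

            Args:
                margin: Margin used by the loss
            """
            self.margin = margin

    parser = LightningArgumentParser()
    # plain argparse-style option:
    parser.add_argument('--notes', default='')
    # every init parameter of Loss becomes a --loss.* option, derived from its type hints:
    parser.add_class_arguments(Loss, 'loss')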

The :class:`~pytorch_lightning.utilities.cli.LightningCLI` class has the
:meth:`~pytorch_lightning.utilities.cli.LightningCLI.add_arguments_to_parser`
method which can be implemented to include more arguments. After parsing, the
configuration is stored in the :code:`config` attribute of the class instance.
The :class:`~pytorch_lightning.utilities.cli.LightningCLI` class also has two
methods that can be used to run code before and after :code:`trainer.fit` is
executed: :meth:`~pytorch_lightning.utilities.cli.LightningCLI.before_fit` and
:meth:`~pytorch_lightning.utilities.cli.LightningCLI.after_fit`. A realistic
example for these would be to send an email before and after the execution of
fit. The code would be something like:

.. testcode::

    from pytorch_lightning.utilities.cli import LightningCLI

    class MyLightningCLI(LightningCLI):

        def add_arguments_to_parser(self, parser):
            parser.add_argument('--notification_email', default='[email protected]')

        def before_fit(self):
            send_email(
                address=self.config['notification_email'],
                message='trainer.fit starting'
            )

        def after_fit(self):
            send_email(
                address=self.config['notification_email'],
                message='trainer.fit finished'
            )

    cli = MyLightningCLI(MyModel)

Note that the config object :code:`self.config` is a dictionary whose keys are
global options or groups of options. It has the same structure as the yaml
format described previously. This means, for instance, that the parameters used
for instantiating the trainer class can be found in
:code:`self.config['trainer']`.
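
For example, a sketch of a hook that reads one of the parsed trainer options:

.. code-block:: python

    class MyLightningCLI(LightningCLI):
        def before_fit(self):
            # trainer options live under the 'trainer' group of the parsed config
            print('Training for at most', self.config['trainer']['max_epochs'], 'epochs')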

For more advanced use cases, other methods of the
:class:`~pytorch_lightning.utilities.cli.LightningCLI` class could be extended.
For further information have a look at the corresponding API reference.
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -383,6 +383,6 @@ def package_list_from_file(file):
    _TORCHVISION_AVAILABLE,
    _module_available,
)
TORCHVISION_AVAILABLE = _module_available("torchvision")
_JSONARGPARSE_AVAILABLE = _module_available("jsonargparse")
"""
coverage_skip_undoc_in_source = True
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -102,6 +102,7 @@ PyTorch Lightning Documentation
    common/early_stopping
    common/fast_training
    common/hyperparameters
    common/lightning_cli
    advanced/lr_finder
    advanced/multi_gpu
    advanced/multiple_loaders