Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Callback collection through entry points #12739

Merged
merged 46 commits into from
May 3, 2022
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
d139467
fix version comparison for python version
awaelchli Apr 14, 2022
7c3f93e
add doctest
awaelchli Apr 14, 2022
99d7eaf
load callback factories from entry point
awaelchli Apr 12, 2022
a3b4c3a
docs for entry points
awaelchli Apr 12, 2022
15d2352
formatting
awaelchli Apr 12, 2022
a4bc246
docs
awaelchli Apr 12, 2022
a6c9569
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2022
3eb6253
handle empty entry points
awaelchli Apr 12, 2022
088e7f3
add log
awaelchli Apr 13, 2022
d44a7de
wip test
awaelchli Apr 14, 2022
8735d4c
wip
awaelchli Apr 14, 2022
785da7d
tests
awaelchli Apr 14, 2022
af6f15e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2022
62652f2
fix import
awaelchli Apr 14, 2022
e903737
Merge remote-tracking branch 'origin/feature/callback-entry-points' i…
awaelchli Apr 14, 2022
f085185
Merge branch 'master' into feature/callback-entry-points
awaelchli Apr 14, 2022
a80b0b5
reset
awaelchli Apr 14, 2022
a248e05
update changelog
awaelchli Apr 14, 2022
85ef5c3
Merge branch 'master' into feature/callback-entry-points
awaelchli Apr 14, 2022
3c7aa32
Merge branch 'master' into feature/callback-entry-points
awaelchli Apr 21, 2022
d10be0d
update title level
awaelchli Apr 21, 2022
8dae38f
extract docs
awaelchli Apr 21, 2022
22fafc6
Merge branch 'master' into feature/callback-entry-points
awaelchli Apr 21, 2022
1942933
Update CHANGELOG.md
awaelchli Apr 21, 2022
caf5cac
Update docs/source/extensions/entry_points.rst
awaelchli Apr 21, 2022
8dbfd3b
Update docs/source/extensions/entry_points.rst
awaelchli Apr 21, 2022
e51f95f
Update docs/source/extensions/entry_points.rst
awaelchli Apr 21, 2022
1ac8b55
Merge remote-tracking branch 'origin/feature/callback-entry-points' i…
awaelchli Apr 21, 2022
8418a80
import locally
awaelchli Apr 21, 2022
b5be0bc
split double call
awaelchli Apr 21, 2022
614642a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 21, 2022
8d08724
add uninstallation guide
awaelchli Apr 21, 2022
219025d
Revert "import locally"
awaelchli Apr 21, 2022
661e8d0
fix test
awaelchli Apr 21, 2022
da04388
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 21, 2022
dbb40b3
support Callback as direct return value
awaelchli Apr 21, 2022
049b68d
fix typo in error message
awaelchli Apr 21, 2022
419e995
revert error handling
awaelchli Apr 21, 2022
7cee194
move method to function
awaelchli Apr 22, 2022
f2ea314
Import locally
carmocca Apr 22, 2022
79be51c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2022
b95406b
Remove monkeypatch
carmocca Apr 22, 2022
81e8c4a
Update tests/trainer/connectors/test_callback_connector.py
carmocca Apr 22, 2022
48422ea
Merge branch 'master' into feature/callback-entry-points
carmocca May 3, 2022
7ce5601
Apply suggestions from code review
carmocca May 3, 2022
1602f82
Merge branch 'master' into feature/callback-entry-points
carmocca May 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Include the `pytorch_lightning` version as a header in the CLI config files ([#12532](https://github.com/PyTorchLightning/pytorch-lightning/pull/12532))


-
- Added support for `Callback` registration through entry points ([#12739](https://github.com/PyTorchLightning/pytorch-lightning/pull/12739))


-
Expand Down
6 changes: 6 additions & 0 deletions docs/source/extensions/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ Lightning has a few built-in callbacks.

----------


**************
Best Practices
**************
Expand All @@ -121,6 +122,11 @@ The following are best practices when using/designing callbacks.
4. Directly calling methods (eg. `on_validation_end`) is strongly discouraged.
5. Whenever possible, your callbacks should not depend on the order in which they are executed.


-----------

.. include:: entry_points.rst

-----------

.. _callback_hooks:
Expand Down
44 changes: 44 additions & 0 deletions docs/source/extensions/entry_points.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
************
Entry Points
************

Lightning supports registering Trainer callbacks directly through
`Entry Points <https://setuptools.pypa.io/en/latest/userguide/entry_point.html>`_. Entry points allow an arbitrary
package to include callbacks that the Lightning Trainer can automatically use, without you having to add them
to the Trainer manually. This is useful in production environments where it is common to provide specialized monitoring
and logging callbacks globally for every application.

Here is a callback factory function that returns two special callbacks:

.. code-block:: python
:caption: factories.py

def my_custom_callbacks_factory():
return [MyCallback1(), MyCallback2()]

If we make this `factories.py` file into an installable package, we can define an **entry point** for this factory function.
Here is a minimal example of the `setup.py` file for the package `my-package`:

.. code-block:: python
:caption: setup.py

from setuptools import setup

setup(
name="my-package",
version="0.0.1",
entry_points={
carmocca marked this conversation as resolved.
Show resolved Hide resolved
"pytorch_lightning.callbacks_factory": [
# The format here must be [any name]=[module path]:[function name]
"monitor_callbacks=factories:my_custom_callbacks_factory"
]
},
)

The group name for the entry points is ``pytorch_lightning.callbacks_factory`` and it contains a list of strings that
specify where to find the function within the package.

Now, if you `pip install -e .` this package, it will register the ``my_custom_callbacks_factory`` function and Lightning
will automatically call it to collect the callbacks whenever you run the Trainer!
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

To unregister the factory, simply uninstall the package with `pip uninstall .` inside the package.
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
carmocca marked this conversation as resolved.
Show resolved Hide resolved
43 changes: 43 additions & 0 deletions pytorch_lightning/trainer/connectors/callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from datetime import timedelta
from typing import Dict, List, Optional, Sequence, Union
Expand All @@ -28,8 +30,17 @@
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.utilities.enums import ModelSummaryMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info

if _PYTHON_GREATER_EQUAL_3_8_0:
import importlib.metadata
else:
from pkg_resources import iter_entry_points


_log = logging.getLogger(__name__)


class CallbackConnector:
def __init__(self, trainer):
Expand Down Expand Up @@ -91,6 +102,8 @@ def on_trainer_init(
if self.trainer.state._fault_tolerant_mode.is_enabled:
self._configure_fault_tolerance_callbacks()

self._configure_external_callbacks()

# push all model checkpoint callbacks to the end
# it is important that these are the last callbacks to run
self.trainer.callbacks = self._reorder_callbacks(self.trainer.callbacks)
Expand Down Expand Up @@ -232,6 +245,36 @@ def _configure_fault_tolerance_callbacks(self):
# don't use `log_dir` to minimize the chances of failure
self.trainer.callbacks.append(_FaultToleranceCheckpoint(dirpath=self.trainer.default_root_dir))

def _configure_external_callbacks(self) -> None:
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
"""Add external callbacks registered through entry points.

The entry points are expected to be functions returning a list of callbacks, which will be added to the Trainer
callback list.
"""
if _PYTHON_GREATER_EQUAL_3_8_0:
factories = importlib.metadata.entry_points().get("pytorch_lightning.callbacks_factory", ())
carmocca marked this conversation as resolved.
Show resolved Hide resolved
else:
factories = iter_entry_points("pytorch_lightning.callbacks_factory")

for factory in factories:
callback_factory = factory.load()
callbacks_list: List[Callback] = callback_factory()
if not isinstance(callbacks_list, list):
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError(
f"The entry point '{factory.name}' returned a {type(callbacks_list)} but is expected to return"
f" a list of `pytorch_lightning.callbacks.Callback`."
)
if not all(isinstance(cb, Callback) for cb in callbacks_list):
raise TypeError(
f"The entry point '{factory.name}' is expected to return a list of callbacks, but at least one"
" callack was not an instance of `pytorch_lightning.callbacks.Callback`."
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
)
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
_log.info(
f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':"
f" {', '.join(type(cb).__name__ for cb in callbacks_list)}"
)
self.trainer.callbacks.extend(callbacks_list)
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

def _attach_model_logging_functions(self):
lightning_module = self.trainer.lightning_module
for callback in self.trainer.callbacks:
Expand Down
70 changes: 70 additions & 0 deletions tests/trainer/connectors/test_callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from unittest.mock import Mock

import pytest
import torch

import pytorch_lightning
carmocca marked this conversation as resolved.
Show resolved Hide resolved
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.callbacks import (
EarlyStopping,
Expand All @@ -26,6 +29,7 @@
TQDMProgressBar,
)
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
from tests.helpers import BoringModel


Expand Down Expand Up @@ -214,3 +218,69 @@ def test_attach_model_callbacks_override_info(caplog):
cb_connector._attach_model_callbacks()

assert "existing callbacks passed to Trainer: EarlyStopping, LearningRateMonitor" in caplog.text


class ExternalCallback(Callback):
"""A callback in another library that gets registered through entry points."""

pass


def test_configure_external_callbacks_raises(monkeypatch):
"""Test that the connector validates the return type of Callback factories registered through entry points."""

def factory_incorrect_return_type():
return "invalid"

def factory_incorrect_element_type():
return [ExternalCallback(), "invalid"]

_make_entry_point_query_mock(monkeypatch, factory_incorrect_return_type)
with pytest.raises(TypeError, match="The entry point 'mocked' returned a <class 'str'> but is expected to return"):
Trainer()

_make_entry_point_query_mock(monkeypatch, factory_incorrect_element_type)
with pytest.raises(TypeError, match="at least one callack was not an instance of"):
Trainer()


def test_configure_external_callbacks(monkeypatch):
"""Test that the connector collects Callback instances from factories registered through entry points."""

def factory_no_callback():
return []

def factory_one_callback():
return [ExternalCallback()]

def factory_multiple_callbacks():
return [ExternalCallback(), ExternalCallback()]

_make_entry_point_query_mock(monkeypatch, factory_no_callback)
trainer = Trainer(enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False)
assert trainer.callbacks == [trainer.accumulation_scheduler] # this scheduler callback gets added by default

_make_entry_point_query_mock(monkeypatch, factory_one_callback)
trainer = Trainer(enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False)
assert isinstance(trainer.callbacks[1], ExternalCallback)

_make_entry_point_query_mock(monkeypatch, factory_multiple_callbacks)
trainer = Trainer(enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False)
assert isinstance(trainer.callbacks[1], ExternalCallback)
assert isinstance(trainer.callbacks[2], ExternalCallback)


def _make_entry_point_query_mock(monkeypatch, callback_factory):
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
query_mock = Mock()
entry_point = Mock()
entry_point.name = "mocked"
entry_point.load.return_value = callback_factory
if _PYTHON_GREATER_EQUAL_3_8_0:
query_mock().get.return_value = [entry_point]
monkeypatch.setattr(
pytorch_lightning.trainer.connectors.callback_connector.importlib.metadata, "entry_points", query_mock
)
return query_mock.get
else:
query_mock.return_value = [entry_point]
monkeypatch.setattr(pytorch_lightning.trainer.connectors.callback_connector, "iter_entry_points", query_mock)