External callback registry through entry points for Fabric #17756

Merged
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -58,6 +58,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for loading a full-state checkpoint file into a sharded model ([#17623](https://github.com/Lightning-AI/lightning/pull/17623))


- Added support for `Callback` registration through entry points ([#17756](https://github.com/Lightning-AI/lightning/pull/17756))


### Changed

- Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
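To illustrate the new entry point group, here is a sketch of how a third-party package might register a callbacks factory for Fabric. Only the group name `lightning.fabric.callbacks_factory` comes from this change; the package, module, and function names below are hypothetical.

```python
# setup.py of a hypothetical third-party package "my-fabric-extras".
# Everything except the entry point group name is made up for illustration.
from setuptools import setup

setup(
    name="my-fabric-extras",
    version="0.1.0",
    packages=["my_fabric_extras"],
    entry_points={
        "lightning.fabric.callbacks_factory": [
            # "<name> = <module>:<callable>"; the callable should return a
            # callback instance or a list of callback instances.
            "extras = my_fabric_extras.callbacks:fabric_callbacks",
        ],
    },
)
```

Once such a package is installed in the same environment, Fabric discovers the factory automatically; no explicit import is required in user code.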
11 changes: 9 additions & 2 deletions src/lightning/fabric/fabric.py
@@ -45,6 +45,7 @@
has_iterable_dataset,
)
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
from lightning.fabric.utilities.registry import _load_external_callbacks
from lightning.fabric.utilities.seed import seed_everything
from lightning.fabric.utilities.types import ReduceOp
from lightning.fabric.utilities.warnings import PossibleUserWarning
@@ -111,8 +112,7 @@ def __init__(
self._strategy: Strategy = self._connector.strategy
self._accelerator: Accelerator = self._connector.accelerator
self._precision: Precision = self._strategy.precision
callbacks = callbacks if callbacks is not None else []
self._callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
self._callbacks = self._configure_callbacks(callbacks)
loggers = loggers if loggers is not None else []
self._loggers = loggers if isinstance(loggers, list) else [loggers]
self._models_setup: int = 0
@@ -908,6 +908,13 @@ def _validate_setup_dataloaders(self, dataloaders: Sequence[DataLoader]) -> None
if any(not isinstance(dl, DataLoader) for dl in dataloaders):
raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")

@staticmethod
def _configure_callbacks(callbacks: Optional[Union[List[Any], Any]]) -> List[Any]:
callbacks = callbacks if callbacks is not None else []
callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
callbacks.extend(_load_external_callbacks("lightning.fabric.callbacks_factory"))
return callbacks


@contextmanager
def _old_sharded_model_context(strategy: Strategy) -> Generator:
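As a usage sketch (the callback class below is hypothetical), callbacks passed to the constructor and callbacks collected from entry points end up in the same `self._callbacks` list, which Fabric later dispatches to via `fabric.call(...)`:

```python
from lightning.fabric import Fabric


class PrintCallback:
    """Hypothetical user callback; Fabric callbacks are plain objects with hook methods."""

    def on_train_epoch_end(self, epoch):
        print(f"epoch {epoch} finished")


# A single callback (or a list) may be passed; `_configure_callbacks` wraps it
# into a list and appends any callbacks returned by factories registered under
# the "lightning.fabric.callbacks_factory" entry point group.
fabric = Fabric(accelerator="cpu", callbacks=PrintCallback())
fabric.call("on_train_epoch_end", epoch=0)  # invokes the hook on every callback that defines it
```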
3 changes: 3 additions & 0 deletions src/lightning/fabric/utilities/imports.py
@@ -30,3 +30,6 @@
_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0")
_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True)
_TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1

_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
44 changes: 43 additions & 1 deletion src/lightning/fabric/utilities/registry.py
@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any
import logging
from typing import Any, List, Union

from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0

_log = logging.getLogger(__name__)


def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> bool:
@@ -25,3 +30,40 @@ def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> bool:
return False

return mod_attr.__code__ is not super_attr.__code__


def _load_external_callbacks(group: str) -> List[Any]:
"""Collect external callbacks registered through entry points.

The entry points are expected to be functions returning a list of callbacks.

Args:
group: The entry point group name to load callbacks from.

Return:
A list of all callbacks collected from external factories.
"""
if _PYTHON_GREATER_EQUAL_3_8_0:
from importlib.metadata import entry_points

factories = (
entry_points(group=group)
if _PYTHON_GREATER_EQUAL_3_10_0
else entry_points().get(group, {}) # type: ignore[arg-type]
)
else:
from pkg_resources import iter_entry_points

factories = iter_entry_points(group) # type: ignore[assignment]

external_callbacks: List[Any] = []
for factory in factories:
callback_factory = factory.load()
callbacks_list: Union[List[Any], Any] = callback_factory()
callbacks_list = [callbacks_list] if not isinstance(callbacks_list, list) else callbacks_list
_log.info(
f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':"
f" {', '.join(type(cb).__name__ for cb in callbacks_list)}"
)
external_callbacks.extend(callbacks_list)
return external_callbacks
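On the provider side, an entry-point factory is just a callable that returns one callback or a list of callbacks; a minimal sketch with hypothetical module and class names:

```python
# my_fabric_extras/callbacks.py -- hypothetical module referenced by an entry
# point such as "extras = my_fabric_extras.callbacks:fabric_callbacks".


class TimerCallback:
    """Hypothetical callback; Fabric triggers its hooks via `fabric.call(...)`."""

    def on_train_start(self):
        print("training started")


def fabric_callbacks():
    # A bare callback instance would also work: `_load_external_callbacks`
    # wraps non-list return values into a list before extending.
    return [TimerCallback()]
```

With the package installed, `_load_external_callbacks("lightning.fabric.callbacks_factory")` would load this factory, call it, and return the resulting instances.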
40 changes: 2 additions & 38 deletions src/lightning/pytorch/trainer/connectors/callback_connector.py
@@ -18,6 +18,7 @@
from typing import Dict, List, Optional, Sequence, Union

import lightning.pytorch as pl
from lightning.fabric.utilities.registry import _load_external_callbacks
from lightning.pytorch.callbacks import (
Callback,
Checkpoint,
@@ -33,7 +34,6 @@
from lightning.pytorch.callbacks.timer import Timer
from lightning.pytorch.trainer import call
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_info

@@ -75,7 +75,7 @@ def on_trainer_init(
# configure the ModelSummary callback
self._configure_model_summary_callback(enable_model_summary)

self.trainer.callbacks.extend(_configure_external_callbacks())
self.trainer.callbacks.extend(_load_external_callbacks("lightning.pytorch.callbacks_factory"))
_validate_callbacks_list(self.trainer.callbacks)

# push all model checkpoint callbacks to the end
@@ -213,42 +213,6 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
return tuner_callbacks + other_callbacks + checkpoint_callbacks


def _configure_external_callbacks() -> List[Callback]:
"""Collect external callbacks registered through entry points.

The entry points are expected to be functions returning a list of callbacks.

Return:
A list of all callbacks collected from external factories.
"""
group = "lightning.pytorch.callbacks_factory"

if _PYTHON_GREATER_EQUAL_3_8_0:
from importlib.metadata import entry_points

factories = (
entry_points(group=group)
if _PYTHON_GREATER_EQUAL_3_10_0
else entry_points().get(group, {}) # type: ignore[arg-type]
)
else:
from pkg_resources import iter_entry_points

factories = iter_entry_points(group) # type: ignore[assignment]

external_callbacks: List[Callback] = []
for factory in factories:
callback_factory = factory.load()
callbacks_list: Union[List[Callback], Callback] = callback_factory()
callbacks_list = [callbacks_list] if isinstance(callbacks_list, Callback) else callbacks_list
_log.info(
f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':"
f" {', '.join(type(cb).__name__ for cb in callbacks_list)}"
)
external_callbacks.extend(callbacks_list)
return external_callbacks


def _validate_callbacks_list(callbacks: List[Callback]) -> None:
stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)]
seen_callbacks = set()
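The Trainer-side behavior is unchanged: the connector now calls the shared helper with the `lightning.pytorch.callbacks_factory` group. For completeness, a hypothetical factory on that side returns `Callback` subclasses (only the group name and hook signature come from Lightning; the rest is made up):

```python
# Hypothetical factory module for the Trainer entry point group
# "lightning.pytorch.callbacks_factory".
from lightning.pytorch.callbacks import Callback


class AuditCallback(Callback):
    def on_fit_start(self, trainer, pl_module):
        print("fit is starting")


def trainer_callbacks():
    return [AuditCallback()]
```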
@@ -11,8 +11,7 @@

import lightning.pytorch as pl
from lightning.fabric.plugins.environments import SLURMEnvironment
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
from lightning.fabric.utilities.imports import _IS_WINDOWS, _PYTHON_GREATER_EQUAL_3_8_0
from lightning.pytorch.utilities.rank_zero import rank_zero_info

# copied from signal.pyi
2 changes: 0 additions & 2 deletions src/lightning/pytorch/utilities/imports.py
@@ -17,8 +17,6 @@
import torch
from lightning_utilities.core.imports import package_available, RequirementCache

_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
_PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)
_TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
_TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0") # using new API with task
64 changes: 64 additions & 0 deletions tests/tests_fabric/utilities/test_registry.py
@@ -0,0 +1,64 @@
import contextlib
from unittest import mock
from unittest.mock import Mock

from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
from lightning.fabric.utilities.registry import _load_external_callbacks


class ExternalCallback:
"""A callback in another library that gets registered through entry points."""

pass


def test_load_external_callbacks():
"""Test that the connector collects Callback instances from factories registered through entry points."""

def factory_no_callback():
return []

def factory_one_callback():
return ExternalCallback()

def factory_one_callback_list():
return [ExternalCallback()]

def factory_multiple_callbacks_list():
return [ExternalCallback(), ExternalCallback()]

with _make_entry_point_query_mock(factory_no_callback):
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
assert callbacks == []

with _make_entry_point_query_mock(factory_one_callback):
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
assert isinstance(callbacks[0], ExternalCallback)

with _make_entry_point_query_mock(factory_one_callback_list):
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
assert isinstance(callbacks[0], ExternalCallback)

with _make_entry_point_query_mock(factory_multiple_callbacks_list):
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
assert isinstance(callbacks[0], ExternalCallback)
assert isinstance(callbacks[1], ExternalCallback)


@contextlib.contextmanager
def _make_entry_point_query_mock(callback_factory):
query_mock = Mock()
entry_point = Mock()
entry_point.name = "mocked"
entry_point.load.return_value = callback_factory
if _PYTHON_GREATER_EQUAL_3_10_0:
query_mock.return_value = [entry_point]
import_path = "importlib.metadata.entry_points"
elif _PYTHON_GREATER_EQUAL_3_8_0:
query_mock().get.return_value = [entry_point]
import_path = "importlib.metadata.entry_points"
else:
query_mock.return_value = [entry_point]
import_path = "pkg_resources.iter_entry_points"
with mock.patch(import_path, query_mock):
yield
@@ -19,6 +19,7 @@
import pytest
import torch

from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
from lightning.pytorch import Callback, LightningModule, Trainer
from lightning.pytorch.callbacks import (
EarlyStopping,
@@ -32,7 +33,6 @@
from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0


def test_checkpoint_callbacks_are_last(tmpdir):
@@ -25,14 +25,14 @@
import torch
from torch import Tensor

from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
from lightning.pytorch import callbacks, Trainer
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.loops import _EvaluationLoop
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
from tests_pytorch.helpers.runif import RunIf

if _RICH_AVAILABLE: