diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 043ae2aab61e3..19ba5a1d11f34 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -7,6 +7,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [UnReleased] - 2023-04-DD +- Added support for `Callback` registration through entry points ([#17756](https://github.com/Lightning-AI/lightning/pull/17756)) + + ### Changed - diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py index e8d9297a5def1..1167a92358d8c 100644 --- a/src/lightning/fabric/fabric.py +++ b/src/lightning/fabric/fabric.py @@ -43,6 +43,7 @@ has_iterable_dataset, ) from lightning.fabric.utilities.distributed import DistributedSamplerWrapper +from lightning.fabric.utilities.registry import _load_external_callbacks from lightning.fabric.utilities.seed import seed_everything from lightning.fabric.utilities.types import ReduceOp from lightning.fabric.utilities.warnings import PossibleUserWarning @@ -105,8 +106,7 @@ def __init__( self._strategy: Strategy = self._connector.strategy self._accelerator: Accelerator = self._connector.accelerator self._precision: Precision = self._strategy.precision - callbacks = callbacks if callbacks is not None else [] - self._callbacks = callbacks if isinstance(callbacks, list) else [callbacks] + self._callbacks = self._configure_callbacks(callbacks) loggers = loggers if loggers is not None else [] self._loggers = loggers if isinstance(loggers, list) else [loggers] self._models_setup: int = 0 @@ -846,6 +846,13 @@ def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None: if any(not isinstance(dl, DataLoader) for dl in dataloaders): raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.") + @staticmethod + def _configure_callbacks(callbacks: Optional[Union[List[Any], Any]]) -> List[Any]: + callbacks = callbacks if callbacks is not None else [] + callbacks = callbacks if isinstance(callbacks, list) else [callbacks] + callbacks.extend(_load_external_callbacks("lightning.fabric.callbacks_factory")) + return callbacks + def _is_using_cli() -> bool: return bool(int(os.environ.get("LT_CLI_USED", "0"))) diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index bdf3d6f3b34c6..6c63ecaa1eeb1 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -30,3 +30,6 @@ _TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0", use_base_version=True) _TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True) _TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1 + +_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8) +_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) diff --git a/src/lightning/fabric/utilities/registry.py b/src/lightning/fabric/utilities/registry.py index 10903d039ec7d..4c3c96dc5803e 100644 --- a/src/lightning/fabric/utilities/registry.py +++ b/src/lightning/fabric/utilities/registry.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from typing import Any +import logging +from typing import Any, List, Union + +from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0 + +_log = logging.getLogger(__name__) def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> bool: @@ -25,3 +30,40 @@ def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> boo return False return mod_attr.__code__ is not super_attr.__code__ + + +def _load_external_callbacks(group: str) -> List[Any]: + """Collect external callbacks registered through entry points. + + The entry points are expected to be functions returning a list of callbacks. + + Args: + group: The entry point group name to load callbacks from. + + Return: + A list of all callbacks collected from external factories. + """ + if _PYTHON_GREATER_EQUAL_3_8_0: + from importlib.metadata import entry_points + + factories = ( + entry_points(group=group) + if _PYTHON_GREATER_EQUAL_3_10_0 + else entry_points().get(group, {}) # type: ignore[arg-type] + ) + else: + from pkg_resources import iter_entry_points + + factories = iter_entry_points(group) # type: ignore[assignment] + + external_callbacks: List[Any] = [] + for factory in factories: + callback_factory = factory.load() + callbacks_list: Union[List[Any], Any] = callback_factory() + callbacks_list = [callbacks_list] if not isinstance(callbacks_list, list) else callbacks_list + _log.info( + f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':" + f" {', '.join(type(cb).__name__ for cb in callbacks_list)}" + ) + external_callbacks.extend(callbacks_list) + return external_callbacks diff --git a/src/lightning/pytorch/trainer/connectors/callback_connector.py b/src/lightning/pytorch/trainer/connectors/callback_connector.py index 8195a5c4a3b52..d649755172658 100644 --- a/src/lightning/pytorch/trainer/connectors/callback_connector.py +++ b/src/lightning/pytorch/trainer/connectors/callback_connector.py @@ -18,6 +18,7 @@ from typing import Dict, List, Optional, Sequence, Union import lightning.pytorch as pl +from lightning.fabric.utilities.registry import _load_external_callbacks from lightning.pytorch.callbacks import ( Callback, Checkpoint, @@ -33,7 +34,6 @@ from lightning.pytorch.callbacks.timer import Timer from lightning.pytorch.trainer import call from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0 from lightning.pytorch.utilities.model_helpers import is_overridden from lightning.pytorch.utilities.rank_zero import rank_zero_info @@ -75,7 +75,7 @@ def on_trainer_init( # configure the ModelSummary callback self._configure_model_summary_callback(enable_model_summary) - self.trainer.callbacks.extend(_configure_external_callbacks()) + self.trainer.callbacks.extend(_load_external_callbacks("lightning.pytorch.callbacks_factory")) _validate_callbacks_list(self.trainer.callbacks) # push all model checkpoint callbacks to the end @@ -213,42 +213,6 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: return tuner_callbacks + other_callbacks + checkpoint_callbacks -def _configure_external_callbacks() -> List[Callback]: - """Collect external callbacks registered through entry points. - - The entry points are expected to be functions returning a list of callbacks. - - Return: - A list of all callbacks collected from external factories. - """ - group = "lightning.pytorch.callbacks_factory" - - if _PYTHON_GREATER_EQUAL_3_8_0: - from importlib.metadata import entry_points - - factories = ( - entry_points(group=group) - if _PYTHON_GREATER_EQUAL_3_10_0 - else entry_points().get(group, {}) # type: ignore[arg-type] - ) - else: - from pkg_resources import iter_entry_points - - factories = iter_entry_points(group) # type: ignore[assignment] - - external_callbacks: List[Callback] = [] - for factory in factories: - callback_factory = factory.load() - callbacks_list: Union[List[Callback], Callback] = callback_factory() - callbacks_list = [callbacks_list] if isinstance(callbacks_list, Callback) else callbacks_list - _log.info( - f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':" - f" {', '.join(type(cb).__name__ for cb in callbacks_list)}" - ) - external_callbacks.extend(callbacks_list) - return external_callbacks - - def _validate_callbacks_list(callbacks: List[Callback]) -> None: stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)] seen_callbacks = set() diff --git a/src/lightning/pytorch/trainer/connectors/signal_connector.py b/src/lightning/pytorch/trainer/connectors/signal_connector.py index 1f48386e3ff86..7e6b7cd0c5e91 100644 --- a/src/lightning/pytorch/trainer/connectors/signal_connector.py +++ b/src/lightning/pytorch/trainer/connectors/signal_connector.py @@ -11,8 +11,7 @@ import lightning.pytorch as pl from lightning.fabric.plugins.environments import SLURMEnvironment -from lightning.fabric.utilities.imports import _IS_WINDOWS -from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0 +from lightning.fabric.utilities.imports import _IS_WINDOWS, _PYTHON_GREATER_EQUAL_3_8_0 from lightning.pytorch.utilities.rank_zero import rank_zero_info # copied from signal.pyi diff --git a/src/lightning/pytorch/utilities/imports.py b/src/lightning/pytorch/utilities/imports.py index 259f18070362b..bfb1eeb5c5174 100644 --- a/src/lightning/pytorch/utilities/imports.py +++ b/src/lightning/pytorch/utilities/imports.py @@ -17,8 +17,6 @@ import torch from lightning_utilities.core.imports import package_available, RequirementCache -_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8) -_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) _PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11) _TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1") _TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0") # using new API with task diff --git a/tests/tests_fabric/utilities/test_registry.py b/tests/tests_fabric/utilities/test_registry.py new file mode 100644 index 0000000000000..75e6e12f5abff --- /dev/null +++ b/tests/tests_fabric/utilities/test_registry.py @@ -0,0 +1,64 @@ +import contextlib +from unittest import mock +from unittest.mock import Mock + +from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0 +from lightning.fabric.utilities.registry import _load_external_callbacks + + +class ExternalCallback: + """A callback in another library that gets registered through entry points.""" + + pass + + +def test_load_external_callbacks(): + """Test that the connector collects Callback instances from factories registered through entry points.""" + + def factory_no_callback(): + return [] + + def factory_one_callback(): + return ExternalCallback() + + def factory_one_callback_list(): + return [ExternalCallback()] + + def factory_multiple_callbacks_list(): + return [ExternalCallback(), ExternalCallback()] + + with _make_entry_point_query_mock(factory_no_callback): + callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory") + assert callbacks == [] + + with _make_entry_point_query_mock(factory_one_callback): + callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory") + assert isinstance(callbacks[0], ExternalCallback) + + with _make_entry_point_query_mock(factory_one_callback_list): + callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory") + assert isinstance(callbacks[0], ExternalCallback) + + with _make_entry_point_query_mock(factory_multiple_callbacks_list): + callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory") + assert isinstance(callbacks[0], ExternalCallback) + assert isinstance(callbacks[1], ExternalCallback) + + +@contextlib.contextmanager +def _make_entry_point_query_mock(callback_factory): + query_mock = Mock() + entry_point = Mock() + entry_point.name = "mocked" + entry_point.load.return_value = callback_factory + if _PYTHON_GREATER_EQUAL_3_10_0: + query_mock.return_value = [entry_point] + import_path = "importlib.metadata.entry_points" + elif _PYTHON_GREATER_EQUAL_3_8_0: + query_mock().get.return_value = [entry_point] + import_path = "importlib.metadata.entry_points" + else: + query_mock.return_value = [entry_point] + import_path = "pkg_resources.iter_entry_points" + with mock.patch(import_path, query_mock): + yield diff --git a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py index 71262df9179e8..58f59ad760763 100644 --- a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py @@ -19,6 +19,7 @@ import pytest import torch +from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0 from lightning.pytorch import Callback, LightningModule, Trainer from lightning.pytorch.callbacks import ( EarlyStopping, @@ -32,7 +33,6 @@ from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector -from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0 def test_checkpoint_callbacks_are_last(tmpdir): diff --git a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py index 36f6356f995cb..cea40b921e1a5 100644 --- a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py @@ -25,6 +25,7 @@ import torch from torch import Tensor +from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0 from lightning.pytorch import callbacks, Trainer from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset @@ -32,7 +33,6 @@ from lightning.pytorch.loops import _EvaluationLoop from lightning.pytorch.trainer.states import RunningStage from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0 from tests_pytorch.helpers.runif import RunIf if _RICH_AVAILABLE: