diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index cd794112305ad..e664e36aa0a63 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -716,6 +716,10 @@ Furthermore, you can register your own optimizers and/or learning rate scheduler ... + # register all `Optimizer` subclasses from the `torch.optim` package + # This is done automatically! + OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer) + cli = LightningCLI(...) .. code-block:: bash diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 7b73b97baf1cd..a81f7ef4fa59c 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -56,7 +56,7 @@ def __call__(self, cls: Type, key: Optional[str] = None, override: bool = False) raise MisconfigurationException(f"'{key}' is already present in the registry. HINT: Use `override=True`.") self[key] = cls - def register_package(self, module: ModuleType, base_cls: Type) -> None: + def register_classes(self, module: ModuleType, base_cls: Type) -> None: """This function is an utility to register all classes from a module.""" for _, cls in inspect.getmembers(module, predicate=inspect.isclass): if issubclass(cls, base_cls) and cls != base_cls: @@ -77,10 +77,10 @@ def __str__(self) -> str: OPTIMIZER_REGISTRY = _Registry() -OPTIMIZER_REGISTRY.register_package(torch.optim, Optimizer) +OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer) LR_SCHEDULER_REGISTRY = _Registry() -LR_SCHEDULER_REGISTRY.register_package(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler) +LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler) class LightningArgumentParser(ArgumentParser):