Avoid circular imports when lightning-habana or lightning-graphcore is installed #18226

Merged · 4 commits · Aug 3, 2023
4 changes: 2 additions & 2 deletions src/lightning/pytorch/trainer/configuration_validator.py
@@ -16,7 +16,7 @@
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _LIGHTNING_GRAPHCORE_AVAILABLE
+from lightning.pytorch.utilities.imports import _lightning_graphcore_available
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
@@ -123,7 +123,7 @@ def __verify_batch_transfer_support(trainer: "pl.Trainer") -> None:
datahook_selector = trainer._data_connector._datahook_selector
assert datahook_selector is not None
for hook in batch_transfer_hooks:
-        if _LIGHTNING_GRAPHCORE_AVAILABLE:
+        if _lightning_graphcore_available():
from lightning_graphcore import IPUAccelerator

# TODO: This code could be done in a hook in the IPUAccelerator as it's a simple error check
26 changes: 13 additions & 13 deletions src/lightning/pytorch/trainer/connectors/accelerator_connector.py
@@ -64,8 +64,8 @@
from lightning.pytorch.utilities.imports import (
_LIGHTNING_BAGUA_AVAILABLE,
_LIGHTNING_COLOSSALAI_AVAILABLE,
-    _LIGHTNING_GRAPHCORE_AVAILABLE,
-    _LIGHTNING_HABANA_AVAILABLE,
+    _lightning_graphcore_available,
+    _lightning_habana_available,
)
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn

@@ -338,12 +338,12 @@ def _choose_auto_accelerator(self) -> str:
"""Choose the accelerator type (str) based on availability."""
if XLAAccelerator.is_available():
return "tpu"
-        if _LIGHTNING_GRAPHCORE_AVAILABLE:
+        if _lightning_graphcore_available():
from lightning_graphcore import IPUAccelerator

if IPUAccelerator.is_available():
return "ipu"
-        if _LIGHTNING_HABANA_AVAILABLE:
+        if _lightning_habana_available():
from lightning_habana import HPUAccelerator

if HPUAccelerator.is_available():
@@ -411,7 +411,7 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment:

def _choose_strategy(self) -> Union[Strategy, str]:
if self._accelerator_flag == "ipu":
-            if not _LIGHTNING_GRAPHCORE_AVAILABLE:
+            if not _lightning_graphcore_available():
raise ImportError(
"You have passed `accelerator='ipu'` but the IPU integration is not installed."
" Please run `pip install lightning-graphcore` or check out"
@@ -421,7 +421,7 @@ def _choose_strategy(self) -> Union[Strategy, str]:

return IPUStrategy.strategy_name
if self._accelerator_flag == "hpu":
-            if not _LIGHTNING_HABANA_AVAILABLE:
+            if not _lightning_habana_available():
raise ImportError(
"You have asked for HPU but you miss install related integration."
" Please run `pip install lightning-habana` or see for further instructions"
@@ -490,7 +490,7 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
if isinstance(self._precision_plugin_flag, PrecisionPlugin):
return self._precision_plugin_flag

-        if _LIGHTNING_GRAPHCORE_AVAILABLE:
+        if _lightning_graphcore_available():
from lightning_graphcore import IPUAccelerator, IPUPrecision

# TODO: For the strategies that have a fixed precision class, we don't really need this logic
@@ -500,7 +500,7 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
if isinstance(self.accelerator, IPUAccelerator):
return IPUPrecision(self._precision_flag)

-        if _LIGHTNING_HABANA_AVAILABLE:
+        if _lightning_habana_available():
from lightning_habana import HPUAccelerator, HPUPrecisionPlugin

if isinstance(self.accelerator, HPUAccelerator):
@@ -567,7 +567,7 @@ def _validate_precision_choice(self) -> None:
f"The `XLAAccelerator` can only be used with a `XLAPrecisionPlugin`,"
f" found: {self._precision_plugin_flag}."
)
-        if _LIGHTNING_HABANA_AVAILABLE:
+        if _lightning_habana_available():
from lightning_habana import HPUAccelerator

if isinstance(self.accelerator, HPUAccelerator) and self._precision_flag not in (
@@ -622,7 +622,7 @@ def _lazy_init_strategy(self) -> None:
f" found {self.strategy.__class__.__name__}."
)

-        if _LIGHTNING_HABANA_AVAILABLE:
+        if _lightning_habana_available():
from lightning_habana import HPUAccelerator, HPUParallelStrategy, SingleHPUStrategy

if isinstance(self.accelerator, HPUAccelerator) and not isinstance(
@@ -641,7 +641,7 @@ def is_distributed(self) -> bool:
DeepSpeedStrategy,
XLAStrategy,
]
-        if _LIGHTNING_HABANA_AVAILABLE:
+        if _lightning_habana_available():
from lightning_habana import HPUParallelStrategy

distributed_strategies.append(HPUParallelStrategy)
@@ -694,7 +694,7 @@ def _register_external_accelerators_and_strategies() -> None:
if "bagua" not in StrategyRegistry:
BaguaStrategy.register_strategies(StrategyRegistry)

-    if _LIGHTNING_HABANA_AVAILABLE:
+    if _lightning_habana_available():
from lightning_habana import HPUAccelerator, HPUParallelStrategy, SingleHPUStrategy

# TODO: Prevent registering multiple times
@@ -705,7 +705,7 @@ def _register_external_accelerators_and_strategies() -> None:
if "hpu_single" not in StrategyRegistry:
SingleHPUStrategy.register_strategies(StrategyRegistry)

-    if _LIGHTNING_GRAPHCORE_AVAILABLE:
+    if _lightning_graphcore_available():
from lightning_graphcore import IPUAccelerator, IPUStrategy

# TODO: Prevent registering multiple times
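
Note: the call sites above all follow the same shape — probe availability through the new function, and only then import the integration package inside that branch, so importing Lightning by itself never pulls in lightning_graphcore or lightning_habana. A minimal sketch of that shape, reusing the helper and accelerator names from this diff in a standalone function that is illustrative only and not part of the Lightning API:

from lightning.pytorch.utilities.imports import _lightning_graphcore_available


def _suggest_accelerator() -> str:
    # Illustrative stand-in for _AcceleratorConnector._choose_auto_accelerator.
    if _lightning_graphcore_available():
        # Deferred import: lightning_graphcore itself imports Lightning, so it
        # must not be imported at module load time.
        from lightning_graphcore import IPUAccelerator

        if IPUAccelerator.is_available():
            return "ipu"
    return "cpu"
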
6 changes: 3 additions & 3 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -34,7 +34,7 @@
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from lightning.pytorch.utilities.data import _is_dataloader_shuffled, _update_dataloader
from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _LIGHTNING_GRAPHCORE_AVAILABLE
+from lightning.pytorch.utilities.imports import _lightning_graphcore_available
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
@@ -165,7 +165,7 @@ def attach_datamodule(
datamodule.trainer = trainer

def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
-        if _LIGHTNING_GRAPHCORE_AVAILABLE:
+        if _lightning_graphcore_available():
from lightning_graphcore import IPUAccelerator

# `DistributedSampler` is never used with `poptorch.DataLoader`
@@ -190,7 +190,7 @@ def _prepare_dataloader(self, dataloader: object, shuffle: bool, mode: RunningSt
if not isinstance(dataloader, DataLoader):
return dataloader

-        if _LIGHTNING_GRAPHCORE_AVAILABLE:
+        if _lightning_graphcore_available():
from lightning_graphcore import IPUAccelerator

# IPUs use a custom `poptorch.DataLoader` which we might need to convert to
10 changes: 5 additions & 5 deletions src/lightning/pytorch/trainer/setup.py
@@ -28,7 +28,7 @@
XLAProfiler,
)
from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _LIGHTNING_GRAPHCORE_AVAILABLE, _LIGHTNING_HABANA_AVAILABLE
+from lightning.pytorch.utilities.imports import _lightning_graphcore_available, _lightning_habana_available
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn


@@ -158,7 +158,7 @@ def _log_device_info(trainer: "pl.Trainer") -> None:
num_tpu_cores = trainer.num_devices if isinstance(trainer.accelerator, XLAAccelerator) else 0
rank_zero_info(f"TPU available: {XLAAccelerator.is_available()}, using: {num_tpu_cores} TPU cores")

-    if _LIGHTNING_GRAPHCORE_AVAILABLE:
+    if _lightning_graphcore_available():
from lightning_graphcore import IPUAccelerator

num_ipus = trainer.num_devices if isinstance(trainer.accelerator, IPUAccelerator) else 0
@@ -168,7 +168,7 @@ def _log_device_info(trainer: "pl.Trainer") -> None:
ipu_available = False
rank_zero_info(f"IPU available: {ipu_available}, using: {num_ipus} IPUs")

-    if _LIGHTNING_HABANA_AVAILABLE:
+    if _lightning_habana_available():
from lightning_habana import HPUAccelerator

num_hpus = trainer.num_devices if isinstance(trainer.accelerator, HPUAccelerator) else 0
@@ -192,13 +192,13 @@ def _log_device_info(trainer: "pl.Trainer") -> None:
if XLAAccelerator.is_available() and not isinstance(trainer.accelerator, XLAAccelerator):
rank_zero_warn("TPU available but not used. You can set it by doing `Trainer(accelerator='tpu')`.")

-    if _LIGHTNING_GRAPHCORE_AVAILABLE:
+    if _lightning_graphcore_available():
from lightning_graphcore import IPUAccelerator

if IPUAccelerator.is_available() and not isinstance(trainer.accelerator, IPUAccelerator):
rank_zero_warn("IPU available but not used. You can set it by doing `Trainer(accelerator='ipu')`.")

-    if _LIGHTNING_HABANA_AVAILABLE:
+    if _lightning_habana_available():
from lightning_habana import HPUAccelerator

if HPUAccelerator.is_available() and not isinstance(trainer.accelerator, HPUAccelerator):
14 changes: 12 additions & 2 deletions src/lightning/pytorch/utilities/imports.py
@@ -41,5 +41,15 @@ def _try_import_module(module_name: str) -> bool:
return False


-_LIGHTNING_GRAPHCORE_AVAILABLE = RequirementCache("lightning-graphcore") and _try_import_module("lightning_graphcore")
-_LIGHTNING_HABANA_AVAILABLE = RequirementCache("lightning-habana") and _try_import_module("lightning_habana")
+@functools.lru_cache(maxsize=1)
+def _lightning_graphcore_available() -> bool:
+    # This is defined as a function instead of a constant to avoid circular imports, because `lightning_graphcore`
+    # also imports Lightning
+    return bool(RequirementCache("lightning-graphcore")) and _try_import_module("lightning_graphcore")
+
+
+@functools.lru_cache(maxsize=1)
+def _lightning_habana_available() -> bool:
+    # This is defined as a function instead of a constant to avoid circular imports, because `lightning_habana`
+    # also imports Lightning
+    return bool(RequirementCache("lightning-habana")) and _try_import_module("lightning_habana")
(accelerator connector tests — file header not captured in this view)
@@ -53,14 +53,14 @@
from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher
from lightning.pytorch.trainer.connectors.accelerator_connector import _AcceleratorConnector, _set_torch_flags
from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _LIGHTNING_GRAPHCORE_AVAILABLE, _LIGHTNING_HABANA_AVAILABLE
+from lightning.pytorch.utilities.imports import _lightning_graphcore_available, _lightning_habana_available
from tests_pytorch.conftest import mock_cuda_count, mock_mps_count, mock_tpu_available, mock_xla_available
from tests_pytorch.helpers.runif import RunIf

-if _LIGHTNING_GRAPHCORE_AVAILABLE:
+if _lightning_graphcore_available():
from lightning_graphcore import IPUAccelerator, IPUStrategy

-if _LIGHTNING_HABANA_AVAILABLE:
+if _lightning_habana_available():
from lightning_habana import HPUAccelerator, SingleHPUStrategy


@@ -935,7 +935,7 @@ def _mock_tpu_available(value):
assert connector.strategy.launcher.is_interactive_compatible

# Single/Multi IPU: strategy is the same
-    if _LIGHTNING_GRAPHCORE_AVAILABLE:
+    if _lightning_graphcore_available():
with monkeypatch.context():
mock_cuda_count(monkeypatch, 0)
mock_mps_count(monkeypatch, 0)
@@ -949,7 +949,7 @@ def _mock_tpu_available(value):
assert connector.strategy.launcher is None

# Single HPU
-    if _LIGHTNING_HABANA_AVAILABLE:
+    if _lightning_habana_available():
import lightning_habana

with monkeypatch.context():
@@ -967,7 +967,7 @@ def _mock_tpu_available(value):
monkeypatch.undo() # for some reason `.context()` is not working properly
_mock_interactive()

-    if not is_interactive and _LIGHTNING_HABANA_AVAILABLE:  # HPU does not support interactive environments
+    if not is_interactive and _lightning_habana_available():  # HPU does not support interactive environments
from lightning_habana import HPUParallelStrategy

# Multi HPU