diff --git a/CHANGELOG.md b/CHANGELOG.md
index a2ea14b23d166..8a20ee5914854 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed error on TPUs when there was no `ModelCheckpoint` ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654))
 - Fixed `trainer.test` freeze on TPUs ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654))
 - Fixed a bug where gradients were disabled after calling `Trainer.predict` ([#6657](https://github.com/PyTorchLightning/pytorch-lightning/pull/6657))
+- Fixed bug where no TPUs were detected in a TPU pod env ([#6719](https://github.com/PyTorchLightning/pytorch-lightning/pull/6719))
 
 
 ## [1.2.5] - 2021-03-23
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index cf3aa06f305b8..e24e4a0db560a 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -14,7 +14,6 @@
 """General utilities"""
 
 import numpy
-
 from pytorch_lightning.utilities.apply_func import move_data_to_device  # noqa: F401
 from pytorch_lightning.utilities.distributed import (  # noqa: F401
     AllGatherGrad,
diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py
index fcf56e9c679f4..294d3d2c5ec40 100644
--- a/pytorch_lightning/utilities/xla_device.py
+++ b/pytorch_lightning/utilities/xla_device.py
@@ -12,18 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
+import os
 import queue as q
 import traceback
 from multiprocessing import Process, Queue
 
-import torch
+import torch.multiprocessing as mp
 
 from pytorch_lightning.utilities.imports import _XLA_AVAILABLE
 
 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
+    import torch_xla.distributed.xla_multiprocessing as xmp
+
 #: define waiting time got checking TPU available in sec
-TPU_CHECK_TIMEOUT = 100
+TPU_CHECK_TIMEOUT = 25
 
 
 def inner_f(queue, func, *args, **kwargs):  # pragma: no cover
@@ -55,23 +58,10 @@ def wrapper(*args, **kwargs):
 class XLADeviceUtils:
     """Used to detect the type of XLA device"""
 
-    TPU_AVAILABLE = None
-
-    @staticmethod
-    def _fetch_xla_device_type(device: torch.device) -> str:
-        """
-        Returns XLA device type
-
-        Args:
-            device: (:class:`~torch.device`): Accepts a torch.device type with a XLA device format i.e xla:0
-
-        Return:
-            Returns a str of the device hardware type. i.e TPU
-        """
-        if _XLA_AVAILABLE:
-            return xm.xla_device_hw(device)
+    _TPU_AVAILABLE = False
 
     @staticmethod
+    @pl_multi_process
     def _is_device_tpu() -> bool:
         """
         Check if device is TPU
@@ -79,10 +69,18 @@ def _is_device_tpu() -> bool:
         Return:
             A boolean value indicating if the xla device is a TPU device or not
         """
-        if _XLA_AVAILABLE:
-            device = xm.xla_device()
-            device_type = XLADeviceUtils._fetch_xla_device_type(device)
-            return device_type == "TPU"
+
+        def _fn(_: int, mp_queue):
+            try:
+                device = xm.xla_device()
+                mp_queue.put(device.type == 'xla')
+            except Exception:
+                mp_queue.put(False)
+
+        smp = mp.get_context("spawn")
+        queue = smp.SimpleQueue()
+        xmp.spawn(_fn, args=(queue, ), nprocs=1)
+        return queue.get()
 
     @staticmethod
     def xla_available() -> bool:
@@ -102,6 +100,14 @@ def tpu_device_exists() -> bool:
         Return:
             A boolean value indicating if a TPU device exists on the system
         """
-        if XLADeviceUtils.TPU_AVAILABLE is None and _XLA_AVAILABLE:
-            XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)()
-        return XLADeviceUtils.TPU_AVAILABLE
+        if os.getenv("PL_TPU_AVAILABLE", '0') == "1":
+            XLADeviceUtils._TPU_AVAILABLE = True
+
+        if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE:
+
+            XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu()
+
+            if XLADeviceUtils._TPU_AVAILABLE:
+                os.environ["PL_TPU_AVAILABLE"] = '1'
+
+        return XLADeviceUtils._TPU_AVAILABLE
diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py
index 872b49ef48635..b0407d1fca6b2 100644
--- a/tests/plugins/test_custom_plugin.py
+++ b/tests/plugins/test_custom_plugin.py
@@ -11,6 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pytest
+import torch
+
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import DDPPlugin
 from tests.helpers import BoringModel
@@ -26,6 +29,7 @@ def __init__(self, **kwargs):
 
 
 @RunIf(skip_windows=True)
+@pytest.mark.skipif(torch.cuda.is_available(), reason="RuntimeError: Tensors must be CUDA and dense")
 def test_sync_batchnorm_set(tmpdir):
     """Tests if sync_batchnorm is automatically set for custom plugin."""
     model = BoringModel()
diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py
index 73b11b48267ce..edca2777b578a 100644
--- a/tests/utilities/test_xla_device_utils.py
+++ b/tests/utilities/test_xla_device_utils.py
@@ -17,29 +17,37 @@
 import pytest
 
 import pytorch_lightning.utilities.xla_device as xla_utils
-from pytorch_lightning.utilities import _TPU_AVAILABLE, _XLA_AVAILABLE
-from tests.helpers.utils import pl_multi_process_test
+from pytorch_lightning.utilities import _XLA_AVAILABLE
+from tests.helpers.runif import RunIf
 
 
 @pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
 def test_tpu_device_absence():
-    """Check tpu_device_exists returns None when torch_xla is not available"""
-    assert xla_utils.XLADeviceUtils.tpu_device_exists() is None
+    """Check tpu_device_exists returns False when torch_xla is not available"""
+    assert not xla_utils.XLADeviceUtils.tpu_device_exists()
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires torch_xla to be installed")
-@pl_multi_process_test
+@RunIf(tpu=True)
 def test_tpu_device_presence():
     """Check tpu_device_exists returns True when TPU is available"""
-    assert xla_utils.XLADeviceUtils.tpu_device_exists() is True
+    assert xla_utils.XLADeviceUtils.tpu_device_exists()
 
 
-@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 10)
+def sleep_fn(sleep_time: float) -> bool:
+    time.sleep(sleep_time)
+    return True
+
+
+@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 3)
+@pytest.mark.skipif(not _XLA_AVAILABLE, reason="test requires torch_xla to be present")
 def test_result_returns_within_timeout_seconds():
-    """Check that pl_multi_process returns within 10 seconds"""
+    """Check that pl_multi_process returns within 3 seconds"""
+    fn = xla_utils.pl_multi_process(sleep_fn)
+
     start = time.time()
-    result = xla_utils.pl_multi_process(time.sleep)(xla_utils.TPU_CHECK_TIMEOUT * 1.25)
+    result = fn(xla_utils.TPU_CHECK_TIMEOUT * 0.5)
     end = time.time()
     elapsed_time = int(end - start)
+
     assert elapsed_time <= xla_utils.TPU_CHECK_TIMEOUT
-    assert result is False
+    assert result
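
Note (not part of the patch): a minimal usage sketch of the reworked detection, using only names that appear in the diff (`XLADeviceUtils.tpu_device_exists` and the `PL_TPU_AVAILABLE` environment variable). After this change the check returns `False` rather than `None` when no TPU is found, and a positive result is cached in `PL_TPU_AVAILABLE` so that processes started later in a TPU pod can skip the forked probe.

```python
import os

from pytorch_lightning.utilities.xla_device import XLADeviceUtils

# Runs the (possibly forked) XLA probe; without torch_xla installed this
# now returns False instead of the previous None.
tpu_available = XLADeviceUtils.tpu_device_exists()
print(f"TPU available: {tpu_available}")

# A positive result is cached in an environment variable so that child
# processes (e.g. the workers spawned on a TPU pod) see it immediately
# and do not repeat the device check.
if tpu_available:
    assert os.environ.get("PL_TPU_AVAILABLE") == "1"
```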