From c70059067f268699188c2b0cb850140f039d365d Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Mon, 29 Mar 2021 18:59:20 +0100
Subject: [PATCH] [TPU] update is_tpu_exists utils internal logic to rely on xmp.spawn (#6719)

* update_logic

* update

* Update tests/utilities/test_xla_device_utils.py

* Update pytorch_lightning/utilities/xla_device.py

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>

* Update pytorch_lightning/utilities/xla_device.py

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>

* update test

* Update tests/utilities/test_xla_device_utils.py

* update

* Apply fix

* Docstring

* flake8

* update

Co-authored-by: Your Name
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: Carlos Mocholi
---
 CHANGELOG.md                              |  1 +
 pytorch_lightning/utilities/__init__.py   |  1 -
 pytorch_lightning/utilities/xla_device.py | 54 +++++++++++++----------
 tests/utilities/test_xla_device_utils.py  | 30 ++++++++-----
 4 files changed, 50 insertions(+), 36 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a2ea14b23d166..8a20ee5914854 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed error on TPUs when there was no `ModelCheckpoint` ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654))
 - Fixed `trainer.test` freeze on TPUs ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654))
 - Fixed a bug where gradients were disabled after calling `Trainer.predict` ([#6657](https://github.com/PyTorchLightning/pytorch-lightning/pull/6657))
+- Fixed bug where no TPUs were detected in a TPU pod env ([#6719](https://github.com/PyTorchLightning/pytorch-lightning/pull/6719))


 ## [1.2.5] - 2021-03-23
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index cf3aa06f305b8..e24e4a0db560a 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -14,7 +14,6 @@
 """General utilities"""

 import numpy
-
 from pytorch_lightning.utilities.apply_func import move_data_to_device  # noqa: F401
 from pytorch_lightning.utilities.distributed import (  # noqa: F401
     AllGatherGrad,
diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py
index fcf56e9c679f4..294d3d2c5ec40 100644
--- a/pytorch_lightning/utilities/xla_device.py
+++ b/pytorch_lightning/utilities/xla_device.py
@@ -12,18 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
+import os
 import queue as q
 import traceback
 from multiprocessing import Process, Queue

-import torch
+import torch.multiprocessing as mp

 from pytorch_lightning.utilities.imports import _XLA_AVAILABLE

 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
+    import torch_xla.distributed.xla_multiprocessing as xmp
+
 #: define waiting time got checking TPU available in sec
-TPU_CHECK_TIMEOUT = 100
+TPU_CHECK_TIMEOUT = 25


 def inner_f(queue, func, *args, **kwargs):  # pragma: no cover
@@ -55,23 +58,10 @@ def wrapper(*args, **kwargs):
 class XLADeviceUtils:
     """Used to detect the type of XLA device"""

-    TPU_AVAILABLE = None
-
-    @staticmethod
-    def _fetch_xla_device_type(device: torch.device) -> str:
-        """
-        Returns XLA device type
-
-        Args:
-            device: (:class:`~torch.device`): Accepts a torch.device type with a XLA device format i.e xla:0
-
-        Return:
-            Returns a str of the device hardware type. i.e TPU
-        """
-        if _XLA_AVAILABLE:
-            return xm.xla_device_hw(device)
+    _TPU_AVAILABLE = False

     @staticmethod
+    @pl_multi_process
     def _is_device_tpu() -> bool:
         """
         Check if device is TPU
@@ -79,10 +69,18 @@ def _is_device_tpu() -> bool:
         Return:
             A boolean value indicating if the xla device is a TPU device or not
         """
-        if _XLA_AVAILABLE:
-            device = xm.xla_device()
-            device_type = XLADeviceUtils._fetch_xla_device_type(device)
-            return device_type == "TPU"
+
+        def _fn(_: int, mp_queue):
+            try:
+                device = xm.xla_device()
+                mp_queue.put(device.type == 'xla')
+            except Exception:
+                mp_queue.put(False)
+
+        smp = mp.get_context("spawn")
+        queue = smp.SimpleQueue()
+        xmp.spawn(_fn, args=(queue, ), nprocs=1)
+        return queue.get()

     @staticmethod
     def xla_available() -> bool:
@@ -102,6 +100,14 @@ def tpu_device_exists() -> bool:
         Return:
             A boolean value indicating if a TPU device exists on the system
         """
-        if XLADeviceUtils.TPU_AVAILABLE is None and _XLA_AVAILABLE:
-            XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)()
-        return XLADeviceUtils.TPU_AVAILABLE
+        if os.getenv("PL_TPU_AVAILABLE", '0') == "1":
+            XLADeviceUtils._TPU_AVAILABLE = True
+
+        if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE:
+
+            XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu()
+
+            if XLADeviceUtils._TPU_AVAILABLE:
+                os.environ["PL_TPU_AVAILABLE"] = '1'
+
+        return XLADeviceUtils._TPU_AVAILABLE
diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py
index 73b11b48267ce..edca2777b578a 100644
--- a/tests/utilities/test_xla_device_utils.py
+++ b/tests/utilities/test_xla_device_utils.py
@@ -17,29 +17,37 @@
 import pytest

 import pytorch_lightning.utilities.xla_device as xla_utils
-from pytorch_lightning.utilities import _TPU_AVAILABLE, _XLA_AVAILABLE
-from tests.helpers.utils import pl_multi_process_test
+from pytorch_lightning.utilities import _XLA_AVAILABLE
+from tests.helpers.runif import RunIf


 @pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
 def test_tpu_device_absence():
-    """Check tpu_device_exists returns None when torch_xla is not available"""
-    assert xla_utils.XLADeviceUtils.tpu_device_exists() is None
+    """Check tpu_device_exists returns False when torch_xla is not available"""
+    assert not xla_utils.XLADeviceUtils.tpu_device_exists()


-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires torch_xla to be installed")
-@pl_multi_process_test
+@RunIf(tpu=True)
 def test_tpu_device_presence():
     """Check tpu_device_exists returns True when TPU is available"""
-    assert xla_utils.XLADeviceUtils.tpu_device_exists() is True
+    assert xla_utils.XLADeviceUtils.tpu_device_exists()


-@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 10)
+def sleep_fn(sleep_time: float) -> bool:
+    time.sleep(sleep_time)
+    return True
+
+
+@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 3)
+@pytest.mark.skipif(not _XLA_AVAILABLE, reason="test requires torch_xla to be present")
 def test_result_returns_within_timeout_seconds():
-    """Check that pl_multi_process returns within 10 seconds"""
+    """Check that pl_multi_process returns within 3 seconds"""
+    fn = xla_utils.pl_multi_process(sleep_fn)
+
     start = time.time()
-    result = xla_utils.pl_multi_process(time.sleep)(xla_utils.TPU_CHECK_TIMEOUT * 1.25)
+    result = fn(xla_utils.TPU_CHECK_TIMEOUT * 0.5)
     end = time.time()
     elapsed_time = int(end - start)
+
     assert elapsed_time <= xla_utils.TPU_CHECK_TIMEOUT
-    assert result is False
+    assert result
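
Usage sketch (illustrative, not part of the patch): with this change, XLADeviceUtils.tpu_device_exists()
runs the TPU probe inside a process spawned by xmp.spawn and, when the probe succeeds, caches the result
in the PL_TPU_AVAILABLE environment variable so that later processes (for example the workers of a TPU
pod) can skip the check. Below is a minimal sketch of how calling code might observe that behaviour,
assuming a pytorch-lightning build that includes this patch; on a machine without torch_xla the check
simply returns False.

    import os

    from pytorch_lightning.utilities.xla_device import XLADeviceUtils

    if XLADeviceUtils.tpu_device_exists():
        # A positive probe is cached in PL_TPU_AVAILABLE, so processes spawned
        # afterwards inherit the answer instead of re-running xmp.spawn.
        assert os.environ.get("PL_TPU_AVAILABLE") == "1"
        print("TPU detected")
    else:
        print("no TPU detected, falling back to CPU/GPU")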