diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index f4617c23da383..03981b0042eac 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """General utilities"""
-import numpy
+import numpy
 
 from pytorch_lightning.utilities.apply_func import move_data_to_device  # noqa: F401
 from pytorch_lightning.utilities.distributed import (  # noqa: F401
     AllGatherGrad,
diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py
index fcf56e9c679f4..294d3d2c5ec40 100644
--- a/pytorch_lightning/utilities/xla_device.py
+++ b/pytorch_lightning/utilities/xla_device.py
@@ -12,18 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
+import os
 import queue as q
 import traceback
 from multiprocessing import Process, Queue
 
-import torch
+import torch.multiprocessing as mp
 
 from pytorch_lightning.utilities.imports import _XLA_AVAILABLE
 
 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
+    import torch_xla.distributed.xla_multiprocessing as xmp
+
 #: define waiting time got checking TPU available in sec
-TPU_CHECK_TIMEOUT = 100
+TPU_CHECK_TIMEOUT = 25
 
 
 def inner_f(queue, func, *args, **kwargs):  # pragma: no cover
@@ -55,23 +58,10 @@ def wrapper(*args, **kwargs):
 class XLADeviceUtils:
     """Used to detect the type of XLA device"""
 
-    TPU_AVAILABLE = None
-
-    @staticmethod
-    def _fetch_xla_device_type(device: torch.device) -> str:
-        """
-        Returns XLA device type
-
-        Args:
-            device: (:class:`~torch.device`): Accepts a torch.device type with a XLA device format i.e xla:0
-
-        Return:
-            Returns a str of the device hardware type. i.e TPU
-        """
-        if _XLA_AVAILABLE:
-            return xm.xla_device_hw(device)
+    _TPU_AVAILABLE = False
 
     @staticmethod
+    @pl_multi_process
     def _is_device_tpu() -> bool:
         """
         Check if device is TPU
@@ -79,10 +69,18 @@ def _is_device_tpu() -> bool:
         Return:
             A boolean value indicating if the xla device is a TPU device or not
         """
-        if _XLA_AVAILABLE:
-            device = xm.xla_device()
-            device_type = XLADeviceUtils._fetch_xla_device_type(device)
-            return device_type == "TPU"
+
+        def _fn(_: int, mp_queue):
+            try:
+                device = xm.xla_device()
+                mp_queue.put(device.type == 'xla')
+            except Exception:
+                mp_queue.put(False)
+
+        smp = mp.get_context("spawn")
+        queue = smp.SimpleQueue()
+        xmp.spawn(_fn, args=(queue, ), nprocs=1)
+        return queue.get()
 
     @staticmethod
     def xla_available() -> bool:
@@ -102,6 +100,14 @@ def tpu_device_exists() -> bool:
         Return:
             A boolean value indicating if a TPU device exists on the system
         """
-        if XLADeviceUtils.TPU_AVAILABLE is None and _XLA_AVAILABLE:
-            XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)()
-        return XLADeviceUtils.TPU_AVAILABLE
+        if os.getenv("PL_TPU_AVAILABLE", '0') == "1":
+            XLADeviceUtils._TPU_AVAILABLE = True
+
+        if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE:
+
+            XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu()
+
+            if XLADeviceUtils._TPU_AVAILABLE:
+                os.environ["PL_TPU_AVAILABLE"] = '1'
+
+        return XLADeviceUtils._TPU_AVAILABLE
diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py
index 02be752e7e2fb..edca2777b578a 100644
--- a/tests/utilities/test_xla_device_utils.py
+++ b/tests/utilities/test_xla_device_utils.py
@@ -19,28 +19,35 @@
 import pytorch_lightning.utilities.xla_device as xla_utils
 from pytorch_lightning.utilities import _XLA_AVAILABLE
 from tests.helpers.runif import RunIf
-from tests.helpers.utils import pl_multi_process_test
 
 
 @pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
 def test_tpu_device_absence():
-    """Check tpu_device_exists returns None when torch_xla is not available"""
-    assert xla_utils.XLADeviceUtils.tpu_device_exists() is None
+    """Check tpu_device_exists returns False when torch_xla is not available"""
+    assert not xla_utils.XLADeviceUtils.tpu_device_exists()
 
 
 @RunIf(tpu=True)
-@pl_multi_process_test
 def test_tpu_device_presence():
     """Check tpu_device_exists returns True when TPU is available"""
-    assert xla_utils.XLADeviceUtils.tpu_device_exists() is True
+    assert xla_utils.XLADeviceUtils.tpu_device_exists()
 
 
-@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 10)
+def sleep_fn(sleep_time: float) -> bool:
+    time.sleep(sleep_time)
+    return True
+
+
+@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 3)
+@pytest.mark.skipif(not _XLA_AVAILABLE, reason="test requires torch_xla to be present")
 def test_result_returns_within_timeout_seconds():
-    """Check that pl_multi_process returns within 10 seconds"""
+    """Check that pl_multi_process returns within 3 seconds"""
+    fn = xla_utils.pl_multi_process(sleep_fn)
+
     start = time.time()
-    result = xla_utils.pl_multi_process(time.sleep)(xla_utils.TPU_CHECK_TIMEOUT * 1.25)
+    result = fn(xla_utils.TPU_CHECK_TIMEOUT * 0.5)
     end = time.time()
     elapsed_time = int(end - start)
+
     assert elapsed_time <= xla_utils.TPU_CHECK_TIMEOUT
-    assert result is False
+    assert result
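
Note (illustration only, not part of the patch): a minimal sketch of how the reworked check would be exercised once the diff above is applied. The import path and the tpu_device_exists method come from the diff; the surrounding script is an assumption.

    from pytorch_lightning.utilities.xla_device import XLADeviceUtils

    # The first call probes for a TPU in a single spawned XLA process
    # (xmp.spawn inside _is_device_tpu); a positive result is cached in the
    # PL_TPU_AVAILABLE environment variable and in the class attribute, so
    # subsequent calls in the same environment return immediately.
    if XLADeviceUtils.tpu_device_exists():
        print("TPU detected")
    else:
        print("no TPU available")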