From eed1d50768b57fedc8f998b1b45bccfbdbede5a6 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Mon, 29 Mar 2021 11:29:48 +0000
Subject: [PATCH 01/12] update_logic

---
 pytorch_lightning/utilities/__init__.py   |  2 +-
 pytorch_lightning/utilities/xla_device.py | 59 +++++++++++++----------
 tests/utilities/test_xla_device_utils.py  | 12 +++--
 3 files changed, 42 insertions(+), 31 deletions(-)

diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index f4617c23da383..03981b0042eac 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """General utilities"""
-import numpy

+import numpy
 from pytorch_lightning.utilities.apply_func import move_data_to_device  # noqa: F401
 from pytorch_lightning.utilities.distributed import (  # noqa: F401
     AllGatherGrad,
diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py
index fcf56e9c679f4..b7dabdf24cccf 100644
--- a/pytorch_lightning/utilities/xla_device.py
+++ b/pytorch_lightning/utilities/xla_device.py
@@ -11,19 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+from typing import Optional
 import functools
 import queue as q
 import traceback
 from multiprocessing import Process, Queue

 import torch
+import torch.multiprocessing as mp

 from pytorch_lightning.utilities.imports import _XLA_AVAILABLE

 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
+    import torch_xla.distributed.xla_multiprocessing as xmp
+

 #: define waiting time got checking TPU available in sec
-TPU_CHECK_TIMEOUT = 100
+
+TPU_CHECK_TIMEOUT = 25


 def inner_f(queue, func, *args, **kwargs):  # pragma: no cover
@@ -55,34 +61,29 @@ def wrapper(*args, **kwargs):


 class XLADeviceUtils:
     """Used to detect the type of XLA device"""

-    TPU_AVAILABLE = None
-
-    @staticmethod
-    def _fetch_xla_device_type(device: torch.device) -> str:
-        """
-        Returns XLA device type
-
-        Args:
-            device: (:class:`~torch.device`): Accepts a torch.device type with a XLA device format i.e xla:0
-
-        Return:
-            Returns a str of the device hardware type. i.e TPU
-        """
-        if _XLA_AVAILABLE:
-            return xm.xla_device_hw(device)
+    _TPU_AVAILABLE = None

     @staticmethod
-    def _is_device_tpu() -> bool:
+    @pl_multi_process
+    def _is_device_tpu() -> Optional[bool]:
         """
         Check if device is TPU

         Return:
             A boolean value indicating if the xla device is a TPU device or not
         """
-        if _XLA_AVAILABLE:
-            device = xm.xla_device()
-            device_type = XLADeviceUtils._fetch_xla_device_type(device)
-            return device_type == "TPU"
+        def _fn(process_idx: int, mp_queue):
+            try:
+                device = xm.xla_device()
+                mp_queue.put(device.type == 'xla')
+            except Exception:
+                mp_queue.put(False)
+
+        smp = mp.get_context("spawn")
+        queue = smp.SimpleQueue()
+        xmp.spawn(_fn, args=(queue, ), nprocs=1)
+        return queue.get()
+

     @staticmethod
     def xla_available() -> bool:
@@ -95,13 +96,21 @@ def xla_available() -> bool:
         return _XLA_AVAILABLE

     @staticmethod
-    def tpu_device_exists() -> bool:
+    def tpu_device_exists() -> Optional[bool]:
         """
         Runs XLA device check within a separate process

         Return:
             A boolean value indicating if a TPU device exists on the system
         """
-        if XLADeviceUtils.TPU_AVAILABLE is None and _XLA_AVAILABLE:
-            XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)()
-        return XLADeviceUtils.TPU_AVAILABLE
+        if os.getenv("PL_TPU_AVAILABLE", '0') == "1":
+            XLADeviceUtils._TPU_AVAILABLE = True
+
+        if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE:
+
+            XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu()
+
+            if XLADeviceUtils._TPU_AVAILABLE:
+                os.environ["PL_TPU_AVAILABLE"] = '1'
+
+        return XLADeviceUtils._TPU_AVAILABLE
diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py
index 02be752e7e2fb..0a1bb540bf7a6 100644
--- a/tests/utilities/test_xla_device_utils.py
+++ b/tests/utilities/test_xla_device_utils.py
@@ -29,18 +29,20 @@ def test_tpu_device_absence():


 @RunIf(tpu=True)
-@pl_multi_process_test
 def test_tpu_device_presence():
     """Check tpu_device_exists returns True when TPU is available"""
-    assert xla_utils.XLADeviceUtils.tpu_device_exists() is True
+    assert xla_utils.XLADeviceUtils.tpu_device_exists()


-@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 10)
+@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 3)
 def test_result_returns_within_timeout_seconds():
     """Check that pl_multi_process returns within 10 seconds"""
+    def fn():
+        time.sleep(xla_utils.TPU_CHECK_TIMEOUT * 0.5)
+        return True
     start = time.time()
-    result = xla_utils.pl_multi_process(time.sleep)(xla_utils.TPU_CHECK_TIMEOUT * 1.25)
+    result = xla_utils.pl_multi_process(fn)()
     end = time.time()
     elapsed_time = int(end - start)
     assert elapsed_time <= xla_utils.TPU_CHECK_TIMEOUT
-    assert result is False
+    assert result

From 0f8c08bf9c5ae436bfaeef74d3dcde6aedff8434 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 29 Mar 2021 12:31:20 +0100
Subject: [PATCH 02/12] update

---
 pytorch_lightning/utilities/xla_device.py | 14 ++++++--------
 tests/utilities/test_xla_device_utils.py  |  3 ++-
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py
index b7dabdf24cccf..879e5a4ecbb1a 100644
--- a/pytorch_lightning/utilities/xla_device.py
+++ b/pytorch_lightning/utilities/xla_device.py
@@ -11,14 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
-from typing import Optional
 import functools
+import os
 import queue as q
 import traceback
 from multiprocessing import Process, Queue
+from typing import Optional

-import torch
 import torch.multiprocessing as mp

 from pytorch_lightning.utilities.imports import _XLA_AVAILABLE
@@ -28,7 +27,6 @@
     import torch_xla.distributed.xla_multiprocessing as xmp

 #: define waiting time got checking TPU available in sec
-
 TPU_CHECK_TIMEOUT = 25
@@ -72,6 +70,7 @@ def _is_device_tpu() -> Optional[bool]:
         Return:
             A boolean value indicating if the xla device is a TPU device or not
         """
+
         def _fn(process_idx: int, mp_queue):
             try:
                 device = xm.xla_device()
@@ -83,7 +82,6 @@ def _fn(process_idx: int, mp_queue):
         smp = mp.get_context("spawn")
         queue = smp.SimpleQueue()
         xmp.spawn(_fn, args=(queue, ), nprocs=1)
         return queue.get()
-

     @staticmethod
     def xla_available() -> bool:
@@ -105,11 +103,11 @@ def tpu_device_exists() -> Optional[bool]:
         """
         if os.getenv("PL_TPU_AVAILABLE", '0') == "1":
             XLADeviceUtils._TPU_AVAILABLE = True
-
+
         if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE:
-
+
             XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu()
-
+
             if XLADeviceUtils._TPU_AVAILABLE:
                 os.environ["PL_TPU_AVAILABLE"] = '1'
diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py
index 0a1bb540bf7a6..2da5c3cd8bf9f 100644
--- a/tests/utilities/test_xla_device_utils.py
+++ b/tests/utilities/test_xla_device_utils.py
@@ -19,7 +19,6 @@
 import pytorch_lightning.utilities.xla_device as xla_utils
 from pytorch_lightning.utilities import _XLA_AVAILABLE
 from tests.helpers.runif import RunIf
-from tests.helpers.utils import pl_multi_process_test


 @pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
@@ -37,9 +36,11 @@ def test_tpu_device_presence():

 @patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 3)
 def test_result_returns_within_timeout_seconds():
     """Check that pl_multi_process returns within 10 seconds"""
+
     def fn():
         time.sleep(xla_utils.TPU_CHECK_TIMEOUT * 0.5)
         return True
+
     start = time.time()
     result = xla_utils.pl_multi_process(fn)()
     end = time.time()

From fdfb6a4c9185ed6d45e6885a6ae5ba022619be7a Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Mon, 29 Mar 2021 17:18:02 +0530
Subject: [PATCH 03/12] Update tests/utilities/test_xla_device_utils.py

---
 tests/utilities/test_xla_device_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py
index 2da5c3cd8bf9f..ff8cc8a7f5453 100644
--- a/tests/utilities/test_xla_device_utils.py
+++ b/tests/utilities/test_xla_device_utils.py
@@ -35,7 +35,7 @@ def test_tpu_device_presence():

 @patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 3)
 def test_result_returns_within_timeout_seconds():
-    """Check that pl_multi_process returns within 10 seconds"""
+    """Check that pl_multi_process returns within 3 seconds"""

     def fn():

From acd348152cb867496824845cd3f6c00c339bd36d Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Mon, 29 Mar 2021 13:21:07 +0100
Subject: [PATCH 04/12] Update pytorch_lightning/utilities/xla_device.py

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
---
 pytorch_lightning/utilities/xla_device.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py
index 879e5a4ecbb1a..901c9f23ac563 100644
--- a/pytorch_lightning/utilities/xla_device.py
+++ b/pytorch_lightning/utilities/xla_device.py
@@ -59,7 +59,7 @@ def wrapper(*args, **kwargs):
 class XLADeviceUtils:
     """Used to detect the type of XLA device"""

-    _TPU_AVAILABLE = None
+    _TPU_AVAILABLE = False

     @staticmethod
     @pl_multi_process

From 47cbd50dfdcbed064e3e4782c38ae21629d3c780 Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Mon, 29 Mar 2021 13:21:13 +0100
Subject: [PATCH 05/12] Update pytorch_lightning/utilities/xla_device.py

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
---
 pytorch_lightning/utilities/xla_device.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py
index 901c9f23ac563..10e08628a8f19 100644
--- a/pytorch_lightning/utilities/xla_device.py
+++ b/pytorch_lightning/utilities/xla_device.py
@@ -63,7 +63,7 @@ class XLADeviceUtils:

     @staticmethod
     @pl_multi_process
-    def _is_device_tpu() -> Optional[bool]:
+    def _is_device_tpu() -> bool:
         """
         Check if device is TPU

From fff33b8364e0a930ae037479ee20fbd4483c046e Mon Sep 17 00:00:00 2001
From: Your Name
Date: Mon, 29 Mar 2021 12:26:47 +0000
Subject: [PATCH 06/12] update test

---
 tests/utilities/test_xla_device_utils.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py
index ff8cc8a7f5453..bb0dab5b64522 100644
--- a/tests/utilities/test_xla_device_utils.py
+++ b/tests/utilities/test_xla_device_utils.py
@@ -33,16 +33,17 @@ def test_tpu_device_presence():
     assert xla_utils.XLADeviceUtils.tpu_device_exists()


+def sleep_fn(sleep_time: float) -> int:
+    time.sleep(sleep_time)
+    return True
+
+
 @patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 3)
 def test_result_returns_within_timeout_seconds():
     """Check that pl_multi_process returns within 3 seconds"""

-    def fn():
-        time.sleep(xla_utils.TPU_CHECK_TIMEOUT * 0.5)
-        return True
-
     start = time.time()
-    result = xla_utils.pl_multi_process(fn)()
+    result = xla_utils.pl_multi_process(sleep_fn)(xla_utils.TPU_CHECK_TIMEOUT * 0.5)
     end = time.time()
     elapsed_time = int(end - start)
     assert elapsed_time <= xla_utils.TPU_CHECK_TIMEOUT

From 521d72737a148cd533e581e4fbe3ed559f83add5 Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Mon, 29 Mar 2021 18:02:27 +0530
Subject: [PATCH 07/12] Update tests/utilities/test_xla_device_utils.py

---
 tests/utilities/test_xla_device_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py
index bb0dab5b64522..7eafbbc5275ca 100644
--- a/tests/utilities/test_xla_device_utils.py
+++ b/tests/utilities/test_xla_device_utils.py
@@ -33,7 +33,7 @@ def test_tpu_device_presence():
     assert xla_utils.XLADeviceUtils.tpu_device_exists()


-def sleep_fn(sleep_time: float) -> int:
+def sleep_fn(sleep_time: float) -> bool:
     time.sleep(sleep_time)
     return True

From dd67ad8f21e3cf8d7f6764df140cd67d65de98fd Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 29 Mar 2021 13:38:19 +0100
Subject: [PATCH 08/12] update

---
 pytorch_lightning/utilities/xla_device.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py
index 10e08628a8f19..f09beecdf0ab5 100644
--- a/pytorch_lightning/utilities/xla_device.py
+++ b/pytorch_lightning/utilities/xla_device.py
@@ -94,7 +94,7 @@ def xla_available() -> bool:
         return _XLA_AVAILABLE

     @staticmethod
-    def tpu_device_exists() -> Optional[bool]:
+    def tpu_device_exists() -> bool:
         """
         Runs XLA device check within a separate process

From cfecd8f4952ef6f2068d0480e376ecab34c3f925 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 29 Mar 2021 14:57:12 +0200
Subject: [PATCH 09/12] Apply fix

---
 tests/utilities/test_xla_device_utils.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py
index 7eafbbc5275ca..bf731ddc9a182 100644
--- a/tests/utilities/test_xla_device_utils.py
+++ b/tests/utilities/test_xla_device_utils.py
@@ -24,7 +24,7 @@
 @pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
 def test_tpu_device_absence():
     """Check tpu_device_exists returns None when torch_xla is not available"""
-    assert xla_utils.XLADeviceUtils.tpu_device_exists() is None
+    assert not xla_utils.XLADeviceUtils.tpu_device_exists()


 @RunIf(tpu=True)
@@ -41,10 +41,12 @@ def sleep_fn(sleep_time: float) -> bool:

 @patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 3)
 def test_result_returns_within_timeout_seconds():
     """Check that pl_multi_process returns within 3 seconds"""
+    fn = xla_utils.pl_multi_process(sleep_fn)

     start = time.time()
-    result = xla_utils.pl_multi_process(sleep_fn)(xla_utils.TPU_CHECK_TIMEOUT * 0.5)
+    result = fn(xla_utils.TPU_CHECK_TIMEOUT * 0.5)
     end = time.time()
     elapsed_time = int(end - start)
+
     assert elapsed_time <= xla_utils.TPU_CHECK_TIMEOUT
     assert result

From 4509771226cae2ac1e1f415334aa197415b0209d Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 29 Mar 2021 14:58:25 +0200
Subject: [PATCH 10/12] Docstring

---
 tests/utilities/test_xla_device_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py
index bf731ddc9a182..46fe6403c93ae 100644
--- a/tests/utilities/test_xla_device_utils.py
+++ b/tests/utilities/test_xla_device_utils.py
@@ -23,7 +23,7 @@

 @pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
 def test_tpu_device_absence():
-    """Check tpu_device_exists returns None when torch_xla is not available"""
+    """Check tpu_device_exists returns False when torch_xla is not available"""
     assert not xla_utils.XLADeviceUtils.tpu_device_exists()

From d178e1d24d2675a4deabc2fee67aa8fc8a06ad50 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 29 Mar 2021 15:02:55 +0200
Subject: [PATCH 11/12] flake8

---
 pytorch_lightning/utilities/xla_device.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py
index f09beecdf0ab5..294d3d2c5ec40 100644
--- a/pytorch_lightning/utilities/xla_device.py
+++ b/pytorch_lightning/utilities/xla_device.py
@@ -16,7 +16,6 @@
 import queue as q
 import traceback
 from multiprocessing import Process, Queue
-from typing import Optional

 import torch.multiprocessing as mp

@@ -71,7 +70,7 @@ def _is_device_tpu() -> bool:
             A boolean value indicating if the xla device is a TPU device or not
         """

-        def _fn(process_idx: int, mp_queue):
+        def _fn(_: int, mp_queue):
             try:
                 device = xm.xla_device()
                 mp_queue.put(device.type == 'xla')

From 40015ddd1030bff70f32a2ab7b1f498a68e83ffb Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 29 Mar 2021 17:18:54 +0100
Subject: [PATCH 12/12] update
---
 tests/utilities/test_xla_device_utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py
index 46fe6403c93ae..edca2777b578a 100644
--- a/tests/utilities/test_xla_device_utils.py
+++ b/tests/utilities/test_xla_device_utils.py
@@ -39,6 +39,7 @@ def sleep_fn(sleep_time: float) -> bool:

 @patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 3)
+@pytest.mark.skipif(not _XLA_AVAILABLE, reason="test requires torch_xla to be present")
 def test_result_returns_within_timeout_seconds():
     """Check that pl_multi_process returns within 3 seconds"""
     fn = xla_utils.pl_multi_process(sleep_fn)
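
Editor's note: for readers following the series, the sketch below consolidates what
pytorch_lightning/utilities/xla_device.py looks like once all twelve patches apply.
The class body, module constants, and imports are taken directly from the hunks above.
The body of pl_multi_process is NOT shown in these diffs (only its decorator usage and
the inner_f signature appear as hunk context), so the wrapper body here is a
reconstruction inferred from what the tests assert: the call returns within
TPU_CHECK_TIMEOUT seconds and yields a falsy result when the child process hangs.
Treat that part as illustrative rather than the exact upstream implementation.

    import functools
    import os
    import queue as q
    import traceback
    from multiprocessing import Process, Queue

    import torch.multiprocessing as mp

    from pytorch_lightning.utilities.imports import _XLA_AVAILABLE

    if _XLA_AVAILABLE:
        import torch_xla.core.xla_model as xm
        import torch_xla.distributed.xla_multiprocessing as xmp

    #: waiting time for the TPU check, in seconds (PATCH 01 lowers it from 100 to 25)
    TPU_CHECK_TIMEOUT = 25


    def inner_f(queue, func, *args, **kwargs):  # pragma: no cover
        # Runs in the child process; body reconstructed, not taken from the diffs.
        try:
            queue.put(func(*args, **kwargs))
        except Exception:
            traceback.print_exc()
            queue.put(None)


    def pl_multi_process(func):
        # Reconstructed wrapper: run ``func`` in a separate process and give up
        # after TPU_CHECK_TIMEOUT seconds, as the timeout test above requires.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            queue = Queue()
            proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs)
            proc.start()
            proc.join(TPU_CHECK_TIMEOUT)
            try:
                return queue.get_nowait()
            except q.Empty:
                # The child did not answer in time: report failure rather than hang.
                return False

        return wrapper


    class XLADeviceUtils:
        """Used to detect the type of XLA device"""

        _TPU_AVAILABLE = False

        @staticmethod
        @pl_multi_process
        def _is_device_tpu() -> bool:
            # Probe inside a single spawned XLA process: creating an xla device
            # succeeds only when an accelerator is actually attached.
            def _fn(_: int, mp_queue):
                try:
                    device = xm.xla_device()
                    mp_queue.put(device.type == 'xla')
                except Exception:
                    mp_queue.put(False)

            smp = mp.get_context("spawn")
            queue = smp.SimpleQueue()
            xmp.spawn(_fn, args=(queue, ), nprocs=1)
            return queue.get()

        @staticmethod
        def xla_available() -> bool:
            return _XLA_AVAILABLE

        @staticmethod
        def tpu_device_exists() -> bool:
            # A positive probe is cached in the environment so that processes
            # created later in the run skip the expensive spawn-based check.
            if os.getenv("PL_TPU_AVAILABLE", '0') == "1":
                XLADeviceUtils._TPU_AVAILABLE = True

            if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE:
                XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu()
                if XLADeviceUtils._TPU_AVAILABLE:
                    os.environ["PL_TPU_AVAILABLE"] = '1'

            return XLADeviceUtils._TPU_AVAILABLE

The design choice worth noting: the probe must run in a freshly spawned process
because initializing XLA in the main process would pin it to the device, and the
env-var cache ("PL_TPU_AVAILABLE") lets that one-time cost be paid only once per job.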