[TPU] update is_tpu_exists utils internal logic to rely on xmp.spawn #6719

Merged (12 commits) on Mar 29, 2021
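For context, the heart of this change is how TPU availability gets probed: instead of asking the current process for the XLA device hardware type, the check now spawns a single XLA worker with xmp.spawn and reports the outcome back through a "spawn"-context queue. Below is a minimal standalone sketch of that pattern; the names tpu_available and _check_tpu are hypothetical, and the body simply mirrors the _is_device_tpu logic in the xla_device.py diff further down.

# Sketch (not the exact PR code): probe for a TPU by spawning one XLA process
# via xmp.spawn and sending the result back through a multiprocessing queue.
import torch.multiprocessing as mp

try:
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
    _XLA_AVAILABLE = True
except ImportError:
    _XLA_AVAILABLE = False


def _check_tpu(_: int, mp_queue) -> None:
    # Runs inside the spawned process; any failure is reported as False
    # instead of being raised into the parent.
    try:
        device = xm.xla_device()
        mp_queue.put(device.type == "xla")
    except Exception:
        mp_queue.put(False)


def tpu_available() -> bool:
    # Hypothetical wrapper name; mirrors XLADeviceUtils._is_device_tpu below.
    if not _XLA_AVAILABLE:
        return False
    smp = mp.get_context("spawn")
    queue = smp.SimpleQueue()
    xmp.spawn(_check_tpu, args=(queue,), nprocs=1)
    return queue.get()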
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/__init__.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities"""
import numpy

import numpy
from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401
from pytorch_lightning.utilities.distributed import ( # noqa: F401
AllGatherGrad,
54 changes: 30 additions & 24 deletions pytorch_lightning/utilities/xla_device.py
@@ -12,18 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import queue as q
import traceback
from multiprocessing import Process, Queue

import torch
import torch.multiprocessing as mp

from pytorch_lightning.utilities.imports import _XLA_AVAILABLE

if _XLA_AVAILABLE:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

#: define waiting time for checking TPU availability in sec
TPU_CHECK_TIMEOUT = 100
TPU_CHECK_TIMEOUT = 25


def inner_f(queue, func, *args, **kwargs): # pragma: no cover
@@ -55,34 +58,29 @@ def wrapper(*args, **kwargs):
class XLADeviceUtils:
"""Used to detect the type of XLA device"""

TPU_AVAILABLE = None

@staticmethod
def _fetch_xla_device_type(device: torch.device) -> str:
"""
Returns XLA device type
Args:
device: (:class:`~torch.device`): Accepts a torch.device type with a XLA device format i.e xla:0
Return:
Returns a str of the device hardware type. i.e TPU
"""
if _XLA_AVAILABLE:
return xm.xla_device_hw(device)
_TPU_AVAILABLE = False

@staticmethod
@pl_multi_process
def _is_device_tpu() -> bool:
"""
Check if device is TPU
Return:
A boolean value indicating if the xla device is a TPU device or not
"""
if _XLA_AVAILABLE:
device = xm.xla_device()
device_type = XLADeviceUtils._fetch_xla_device_type(device)
return device_type == "TPU"

def _fn(_: int, mp_queue):
try:
device = xm.xla_device()
mp_queue.put(device.type == 'xla')
except Exception:
mp_queue.put(False)

smp = mp.get_context("spawn")
queue = smp.SimpleQueue()
xmp.spawn(_fn, args=(queue, ), nprocs=1)
return queue.get()

@staticmethod
def xla_available() -> bool:
@@ -102,6 +100,14 @@ def tpu_device_exists() -> bool:
Return:
A boolean value indicating if a TPU device exists on the system
"""
if XLADeviceUtils.TPU_AVAILABLE is None and _XLA_AVAILABLE:
XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)()
return XLADeviceUtils.TPU_AVAILABLE
if os.getenv("PL_TPU_AVAILABLE", '0') == "1":
XLADeviceUtils._TPU_AVAILABLE = True

if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE:

XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu()

if XLADeviceUtils._TPU_AVAILABLE:
os.environ["PL_TPU_AVAILABLE"] = '1'

return XLADeviceUtils._TPU_AVAILABLE
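In short, the rewritten tpu_device_exists caches its answer twice: in the class attribute _TPU_AVAILABLE and, once a TPU has been detected, in the PL_TPU_AVAILABLE environment variable so that child processes inheriting the environment skip the xmp.spawn probe entirely. A hedged usage sketch based on the diff above:

# Usage sketch: the first call may spawn an XLA process; later calls (and any
# child processes that inherit the environment) read the cached result instead.
import os

from pytorch_lightning.utilities.xla_device import XLADeviceUtils

if __name__ == "__main__":
    if XLADeviceUtils.tpu_device_exists():
        # A successful detection is mirrored into the environment for reuse.
        assert os.environ.get("PL_TPU_AVAILABLE") == "1"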
25 changes: 16 additions & 9 deletions tests/utilities/test_xla_device_utils.py
@@ -19,28 +19,35 @@
import pytorch_lightning.utilities.xla_device as xla_utils
from pytorch_lightning.utilities import _XLA_AVAILABLE
from tests.helpers.runif import RunIf
from tests.helpers.utils import pl_multi_process_test


@pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
def test_tpu_device_absence():
"""Check tpu_device_exists returns None when torch_xla is not available"""
assert xla_utils.XLADeviceUtils.tpu_device_exists() is None
"""Check tpu_device_exists returns False when torch_xla is not available"""
assert not xla_utils.XLADeviceUtils.tpu_device_exists()


@RunIf(tpu=True)
@pl_multi_process_test
def test_tpu_device_presence():
"""Check tpu_device_exists returns True when TPU is available"""
assert xla_utils.XLADeviceUtils.tpu_device_exists() is True
assert xla_utils.XLADeviceUtils.tpu_device_exists()


@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 10)
def sleep_fn(sleep_time: float) -> bool:
time.sleep(sleep_time)
return True


@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 3)
@pytest.mark.skipif(not _XLA_AVAILABLE, reason="test requires torch_xla to be present")
def test_result_returns_within_timeout_seconds():
"""Check that pl_multi_process returns within 10 seconds"""
"""Check that pl_multi_process returns within 3 seconds"""
fn = xla_utils.pl_multi_process(sleep_fn)

start = time.time()
result = xla_utils.pl_multi_process(time.sleep)(xla_utils.TPU_CHECK_TIMEOUT * 1.25)
result = fn(xla_utils.TPU_CHECK_TIMEOUT * 0.5)
end = time.time()
elapsed_time = int(end - start)

assert elapsed_time <= xla_utils.TPU_CHECK_TIMEOUT
assert result is False
assert result
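The updated test wraps a well-behaved sleep_fn with pl_multi_process and patches TPU_CHECK_TIMEOUT down to 3 seconds so the suite stays fast while still exercising the timeout path. The sketch below illustrates the decorator's contract as these tests use it, outside of pytest; quick_check is a hypothetical example function, and per the tests the wrapper returns the real result when the call finishes within TPU_CHECK_TIMEOUT and falls back to False when it does not.

# Sketch of pl_multi_process as exercised by the tests above: the wrapped
# function runs in a separate process and its return value is handed back
# through the wrapper if it arrives before TPU_CHECK_TIMEOUT expires.
import time

import pytorch_lightning.utilities.xla_device as xla_utils


def quick_check() -> bool:
    time.sleep(1)  # comfortably under the timeout
    return True


if __name__ == "__main__":
    wrapped = xla_utils.pl_multi_process(quick_check)
    assert wrapped() is True  # finished in time, so the real result comes back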