
[TPU] update is_tpu_exists utils internal logic to rely on xmp.spawn #6719

Merged 12 commits on Mar 29, 2021
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/__init__.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities"""
import numpy

import numpy
from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401
from pytorch_lightning.utilities.distributed import ( # noqa: F401
AllGatherGrad,
59 changes: 33 additions & 26 deletions pytorch_lightning/utilities/xla_device.py
@@ -12,18 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import queue as q
import traceback
from multiprocessing import Process, Queue
from typing import Optional

import torch
import torch.multiprocessing as mp

from pytorch_lightning.utilities.imports import _XLA_AVAILABLE

if _XLA_AVAILABLE:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

#: define waiting time for checking TPU availability in sec
TPU_CHECK_TIMEOUT = 100
TPU_CHECK_TIMEOUT = 25


def inner_f(queue, func, *args, **kwargs): # pragma: no cover
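inner_f and the pl_multi_process wrapper (collapsed in this view) run a callable in a separate process and bound it with TPU_CHECK_TIMEOUT, falling back to False when no result arrives in time. A rough sketch of that pattern, with hypothetical names rather than the actual Lightning implementation:

import functools
import queue as q
from multiprocessing import Process, Queue

TPU_CHECK_TIMEOUT = 25  # seconds, mirroring the constant above


def _report_result(result_queue, func, *args, **kwargs):
    # Run the wrapped callable in the child process and ship its result back.
    try:
        result_queue.put(func(*args, **kwargs))
    except Exception:
        result_queue.put(None)


def run_with_timeout(func):
    """Hypothetical pl_multi_process-style decorator: returns False if the child overruns."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        result_queue = Queue()
        proc = Process(target=_report_result, args=(result_queue, func) + args, kwargs=kwargs)
        proc.start()
        proc.join(TPU_CHECK_TIMEOUT)
        if proc.is_alive():
            # The child did not finish in time: stop it and report "no TPU".
            proc.terminate()
            return False
        try:
            return result_queue.get_nowait()
        except q.Empty:
            return False

    return wrapper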
@@ -55,34 +59,29 @@ def wrapper(*args, **kwargs):
class XLADeviceUtils:
"""Used to detect the type of XLA device"""

TPU_AVAILABLE = None

@staticmethod
def _fetch_xla_device_type(device: torch.device) -> str:
"""
Returns XLA device type

Args:
device: (:class:`~torch.device`): Accepts a torch.device type with an XLA device format, i.e. xla:0

Return:
Returns a str of the device hardware type, i.e. TPU
"""
if _XLA_AVAILABLE:
return xm.xla_device_hw(device)
_TPU_AVAILABLE = None

@staticmethod
def _is_device_tpu() -> bool:
@pl_multi_process
def _is_device_tpu() -> Optional[bool]:
"""
Check if device is TPU

Return:
A boolean value indicating if the xla device is a TPU device or not
"""
if _XLA_AVAILABLE:
device = xm.xla_device()
device_type = XLADeviceUtils._fetch_xla_device_type(device)
return device_type == "TPU"

def _fn(process_idx: int, mp_queue):
try:
device = xm.xla_device()
mp_queue.put(device.type == 'xla')
except Exception:
mp_queue.put(False)

smp = mp.get_context("spawn")
queue = smp.SimpleQueue()
xmp.spawn(_fn, args=(queue, ), nprocs=1)
return queue.get()
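Pulled out as a standalone sketch (a hypothetical probe_tpu that simply mirrors the method body above and assumes torch_xla is importable), the new check reads:

import torch.multiprocessing as mp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def probe_tpu() -> bool:
    # Spawn a single XLA process; it reports through the queue whether it
    # obtained an 'xla' device, i.e. whether a TPU is actually reachable.
    def _fn(process_idx: int, mp_queue):
        try:
            device = xm.xla_device()
            mp_queue.put(device.type == "xla")
        except Exception:
            mp_queue.put(False)

    smp = mp.get_context("spawn")
    queue = smp.SimpleQueue()
    xmp.spawn(_fn, args=(queue, ), nprocs=1)
    return queue.get()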

@staticmethod
def xla_available() -> bool:
@@ -95,13 +94,21 @@ def xla_available() -> bool:
return _XLA_AVAILABLE

@staticmethod
def tpu_device_exists() -> bool:
def tpu_device_exists() -> Optional[bool]:
"""
Runs XLA device check within a separate process

Return:
A boolean value indicating if a TPU device exists on the system
"""
if XLADeviceUtils.TPU_AVAILABLE is None and _XLA_AVAILABLE:
XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)()
return XLADeviceUtils.TPU_AVAILABLE
if os.getenv("PL_TPU_AVAILABLE", '0') == "1":
XLADeviceUtils._TPU_AVAILABLE = True

if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE:

XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu()

if XLADeviceUtils._TPU_AVAILABLE:
os.environ["PL_TPU_AVAILABLE"] = '1'

return XLADeviceUtils._TPU_AVAILABLE
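A minimal usage sketch: the first call may spawn a short-lived XLA process to run the probe, after which the answer is cached both on the class and, through the PL_TPU_AVAILABLE environment variable, for anything that inherits the environment later on.

import os

from pytorch_lightning.utilities.xla_device import XLADeviceUtils

if XLADeviceUtils.tpu_device_exists():
    # Once the probe succeeds, the flag is exported so later checks
    # (including in spawned workers) can skip the expensive spawn.
    assert os.environ["PL_TPU_AVAILABLE"] == "1"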
17 changes: 10 additions & 7 deletions tests/utilities/test_xla_device_utils.py
@@ -19,7 +19,6 @@
import pytorch_lightning.utilities.xla_device as xla_utils
from pytorch_lightning.utilities import _XLA_AVAILABLE
from tests.helpers.runif import RunIf
from tests.helpers.utils import pl_multi_process_test


@pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
@@ -29,18 +28,22 @@ def test_tpu_device_absence():


@RunIf(tpu=True)
@pl_multi_process_test
def test_tpu_device_presence():
"""Check tpu_device_exists returns True when TPU is available"""
assert xla_utils.XLADeviceUtils.tpu_device_exists() is True
assert xla_utils.XLADeviceUtils.tpu_device_exists()


@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 10)
@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 3)
def test_result_returns_within_timeout_seconds():
"""Check that pl_multi_process returns within 10 seconds"""
"""Check that pl_multi_process returns within 3 seconds"""

def fn():
time.sleep(xla_utils.TPU_CHECK_TIMEOUT * 0.5)
return True

start = time.time()
result = xla_utils.pl_multi_process(time.sleep)(xla_utils.TPU_CHECK_TIMEOUT * 1.25)
result = xla_utils.pl_multi_process(fn)()
end = time.time()
elapsed_time = int(end - start)
assert elapsed_time <= xla_utils.TPU_CHECK_TIMEOUT
assert result is False
assert result
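The removed assertions above exercised the timeout path instead (sleeping past TPU_CHECK_TIMEOUT and expecting False back). If that coverage is still wanted, a hypothetical companion test along the same lines could look like:

import time
from unittest.mock import patch

import pytorch_lightning.utilities.xla_device as xla_utils


@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 3)
def test_result_returns_false_on_timeout():
    """Check that pl_multi_process falls back to False when the wrapped call overruns."""

    def fn():
        time.sleep(xla_utils.TPU_CHECK_TIMEOUT * 1.25)
        return True

    result = xla_utils.pl_multi_process(fn)()
    assert result is False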