[TPU] Correct the check for TPU device in a pod environment #6755

Closed
wants to merge 12 commits
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -19,10 +19,10 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
-from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities import _XLA_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import move_data_to_device

-if _TPU_AVAILABLE:
+if _XLA_AVAILABLE:
import torch_xla.core.xla_model as xm


4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -23,13 +23,13 @@
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn, _OMEGACONF_AVAILABLE
+from pytorch_lightning.utilities import _XLA_AVAILABLE, rank_zero_warn, _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.apply_func import apply_to_collection

-if _TPU_AVAILABLE:
+if _XLA_AVAILABLE:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xla_pl
import torch_xla.distributed.xla_multiprocessing as xmp
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -54,12 +54,12 @@
_APEX_AVAILABLE,
_HOROVOD_AVAILABLE,
_NATIVE_AMP_AVAILABLE,
-_TPU_AVAILABLE,

Contributor:
Why did you remove _TPU_AVAILABLE?

Author:
I replaced the _TPU_AVAILABLE flag with a call to XLADeviceUtils.tpu_device_exists(), because the check now spawns a process to decide the value, and Python complains (about spawning before the current process has finished bootstrapping) when that happens at module import time. A minimal sketch of this deferred-check pattern appears right after this file's diff.

Contributor:
Please add back _TPU_AVAILABLE. We use it as part of the common imports API in Lightning, and removing it won't be backward compatible.

AMPType,
device_parser,
DeviceType,
DistributedType,
rank_zero_only,
+XLADeviceUtils,
)
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -555,7 +555,8 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):

rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}')
num_cores = self.tpu_cores if self.tpu_cores is not None else 0
-rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores')
+tpu_available = XLADeviceUtils.tpu_device_exists()

Contributor:
We can use _TPU_AVAILABLE there.

Author:
_TPU_AVAILABLE is replaced by XLADeviceUtils.tpu_device_exists().

+rank_zero_info(f'TPU available: {tpu_available}, using: {num_cores} TPU cores')

if torch.cuda.is_available() and self._device_type != DeviceType.GPU:
rank_zero_warn(
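A note on the import-time issue raised in the review thread above: computing a module-level constant that spawns a process means the spawn runs as a side effect of importing pytorch_lightning, which is what triggers the bootstrapping complaint. The following is a minimal sketch of the deferred, cached alternative; it is illustrative only, using plain multiprocessing instead of torch_xla, and the names _probe and tpu_device_exists_sketch are hypothetical, not the PR's actual code.

import functools
import multiprocessing
import queue as pyqueue

TPU_CHECK_TIMEOUT = 120  # seconds, mirroring the constant used in this PR


def _probe(result_queue):
    # Stand-in for the real check; the PR's version asks torch_xla for a device.
    result_queue.put(True)


@functools.lru_cache(maxsize=1)
def tpu_device_exists_sketch() -> bool:
    """Run the probe in a short-lived spawned process, with a timeout."""
    ctx = multiprocessing.get_context("spawn")
    result_queue = ctx.Queue()
    proc = ctx.Process(target=_probe, args=(result_queue,))
    proc.start()
    proc.join(TPU_CHECK_TIMEOUT)
    if proc.is_alive():
        proc.terminate()
        proc.join()
    try:
        return bool(result_queue.get_nowait())
    except pyqueue.Empty:
        return False


if __name__ == "__main__":
    # The spawn happens at call time, inside the __main__ guard, never as an
    # import side effect, so re-importing this module is always safe.
    print(tpu_device_exists_sketch())

The PR itself routes the probe through xmp.spawn (see the xla_device.py diff below) rather than a bare multiprocessing.Process, but the call-time-plus-cache structure is the same idea.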
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -25,7 +25,7 @@
from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.supporters import TensorRunningAccum
-from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, parsing
+from pytorch_lightning.utilities import AMPType, DeviceType, parsing, XLADeviceUtils
from pytorch_lightning.utilities.distributed import rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
@@ -408,7 +408,7 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_
optimizer,
opt_idx,
train_step_and_backward_closure,
-on_tpu=self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE,
+on_tpu=self.trainer._device_type == DeviceType.TPU and XLADeviceUtils.tpu_device_exists(),
using_native_amp=using_native_amp,
using_lbfgs=is_lbfgs,
)
2 changes: 0 additions & 2 deletions pytorch_lightning/utilities/__init__.py
@@ -50,8 +50,6 @@
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable # noqa: F401
from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: F401

-_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps
4 changes: 2 additions & 2 deletions pytorch_lightning/utilities/device_parser.py
@@ -15,7 +15,7 @@

import torch

-from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities import XLADeviceUtils
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -109,7 +109,7 @@ def parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[List[int
if not _tpu_cores_valid(tpu_cores):
raise MisconfigurationException("`tpu_cores` can only be 1, 8 or [<1-8>]")

-if tpu_cores is not None and not _TPU_AVAILABLE:
+if tpu_cores is not None and not XLADeviceUtils.tpu_device_exists():
raise MisconfigurationException('No TPU devices were found.')

return tpu_cores
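For context, parse_tpu_cores is the validation helper touched here; the only change is that the availability check now calls XLADeviceUtils.tpu_device_exists() lazily instead of reading the module-level constant. A hypothetical call site, assuming the module lives at pytorch_lightning.utilities.device_parser as in this diff; the surrounding code is illustrative:

from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    # Valid values are 1, 8, or a one-element list [<1-8>] selecting a core.
    tpu_cores = device_parser.parse_tpu_cores(8)
except MisconfigurationException:
    # Raised for invalid values, or when no TPU devices were found
    # (now determined via XLADeviceUtils.tpu_device_exists()).
    tpu_cores = None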
57 changes: 28 additions & 29 deletions pytorch_lightning/utilities/xla_device.py
@@ -15,7 +15,6 @@
import os
import queue as q
import traceback
-from multiprocessing import Process, Queue

import torch.multiprocessing as mp

@@ -26,31 +25,33 @@
import torch_xla.distributed.xla_multiprocessing as xmp

#: define waiting time got checking TPU available in sec
-TPU_CHECK_TIMEOUT = 25
+TPU_CHECK_TIMEOUT = 120


-def inner_f(queue, func, *args, **kwargs):  # pragma: no cover
-    try:
-        queue.put(func(*args, **kwargs))
-    # todo: specify the possible exception
-    except Exception:
-        traceback.print_exc()
-        queue.put(None)
+def inner_f(index, queue, func, *args):  # pragma: no cover
+    queue.put(func(index, *args))


def pl_multi_process(func):

    @functools.wraps(func)
-    def wrapper(*args, **kwargs):
-        queue = Queue()
-        proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs)
-        proc.start()
-        proc.join(TPU_CHECK_TIMEOUT)
+    def wrapper(*args):
+        smp = mp.get_context("spawn")
+        queue = smp.Queue()
+        cxt = xmp.spawn(inner_f, args=(queue, func, *args), join=False)
+
+        # errors in the subprocesses are caught and saved in the error_queues
+        # inside the context, but we don't bother to check them.
+        if not cxt.join(TPU_CHECK_TIMEOUT):
+            for proc in cxt.processes:
+                if proc.is_alive():
+                    proc.terminate()
+                    proc.join()

        try:
            return queue.get_nowait()
        except q.Empty:
            traceback.print_exc()
-            return False
+            return None

    return wrapper

@@ -61,26 +62,24 @@ class XLADeviceUtils:
    _TPU_AVAILABLE = False

    @staticmethod
-    @pl_multi_process
-    def _is_device_tpu() -> bool:
+    def _is_device_tpu(index) -> bool:
        """
        Check if device is TPU

        Return:
            A boolean value indicating if the xla device is a TPU device or not
        """
-        if not _XLA_AVAILABLE:
-            return False
-
-        def _fn(_: int, mp_queue):
-            try:
-                device = xm.xla_device()
-                mp_queue.put(device.type == 'xla')
-            except Exception:
-                mp_queue.put(False)
-
-        smp = mp.get_context("spawn")
-        queue = smp.SimpleQueue()
-        xmp.spawn(_fn, args=(queue, ), nprocs=1)
-        return queue.get()
+        try:
+            device = xm.xla_device()
+            return device.type == 'xla'
+        # Missing XLA Configuration
+        except RuntimeError as e:

Contributor (tchaton, Mar 31, 2021):
Comments can be removed if not used.

Author:
okay

+            traceback.print_exc()
+            return False

@staticmethod
def xla_available() -> bool:
@@ -105,7 +104,7 @@ def tpu_device_exists() -> bool:

if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE:

-XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu()
+XLADeviceUtils._TPU_AVAILABLE = bool(pl_multi_process(XLADeviceUtils._is_device_tpu)())

if XLADeviceUtils._TPU_AVAILABLE:
os.environ["PL_TPU_AVAILABLE"] = '1'
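Taken together, the call sites in this PR no longer read a module-level _TPU_AVAILABLE constant; they call XLADeviceUtils.tpu_device_exists(), which, per the diff above, spawns the probe lazily, caches a positive result in the _TPU_AVAILABLE class attribute, and exports PL_TPU_AVAILABLE=1 for child processes. A minimal usage sketch, assuming only what the diff shows:

import os

from pytorch_lightning.utilities import XLADeviceUtils

# On a host without torch_xla, xla_available() short-circuits and this returns
# False without spawning; on an XLA host the first call pays the probe spawn
# plus the (up to) TPU_CHECK_TIMEOUT wait, and a positive result is cached for
# later calls.
if XLADeviceUtils.tpu_device_exists():
    print("TPU detected; PL_TPU_AVAILABLE =", os.environ.get("PL_TPU_AVAILABLE"))
else:
    print("No TPU devices were found.")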