From fd4f47151b21fe29a19beac1c11b8d86a030e4db Mon Sep 17 00:00:00 2001
From: Jiasen Wu
Date: Tue, 30 Mar 2021 22:32:11 +0200
Subject: [PATCH 01/12] fixup in a v2-32 env

---
 pytorch_lightning/utilities/xla_device.py | 51 ++++++++++++-----------
 1 file changed, 26 insertions(+), 25 deletions(-)

diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py
index 294d3d2c5ec40..493d3f4971235 100644
--- a/pytorch_lightning/utilities/xla_device.py
+++ b/pytorch_lightning/utilities/xla_device.py
@@ -15,7 +15,6 @@
 import os
 import queue as q
 import traceback
-from multiprocessing import Process, Queue
 
 import torch.multiprocessing as mp
 
@@ -29,28 +28,30 @@
 TPU_CHECK_TIMEOUT = 25
 
 
-def inner_f(queue, func, *args, **kwargs):  # pragma: no cover
-    try:
-        queue.put(func(*args, **kwargs))
-    # todo: specify the possible exception
-    except Exception:
-        traceback.print_exc()
-        queue.put(None)
+def inner_f(index, queue, func, *args):  # pragma: no cover
+    queue.put(func(index, *args))
 
 
 def pl_multi_process(func):
 
     @functools.wraps(func)
-    def wrapper(*args, **kwargs):
-        queue = Queue()
-        proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs)
-        proc.start()
-        proc.join(TPU_CHECK_TIMEOUT)
+    def wrapper(*args):
+        smp = mp.get_context("spawn")
+        queue = smp.Queue()
+        cxt = xmp.spawn(inner_f, args=(queue, func, *args), join=False)
+
+        # errors in the subprocesses are caught and saved in the error_queues
+        # inside the context, but we don't bother to check them.
+        if not cxt.join(TPU_CHECK_TIMEOUT):
+            for proc in cxt.processes:
+                if proc.is_alive():
+                    proc.terminate()
+                    proc.join()
+
         try:
             return queue.get_nowait()
         except q.Empty:
-            traceback.print_exc()
-            return False
+            return None
 
     return wrapper
@@ -69,18 +70,18 @@ def _is_device_tpu() -> bool:
         Return:
             A boolean value indicating if the xla device is a TPU device or not
         """
+        if not _XLA_AVAILABLE:
+            return False
 
-        def _fn(_: int, mp_queue):
-            try:
-                device = xm.xla_device()
-                mp_queue.put(device.type == 'xla')
-            except Exception:
-                mp_queue.put(False)
+        try:
+            device = xm.xla_device()
+            device_type = XLADeviceUtils._fetch_xla_device_type(device)
+            return device_type == "TPU"
 
-        smp = mp.get_context("spawn")
-        queue = smp.SimpleQueue()
-        xmp.spawn(_fn, args=(queue, ), nprocs=1)
-        return queue.get()
+        # Missing XLA Configuration
+        except RuntimeError as e:
+            traceback.print_exc()
+            return False
 
     @staticmethod
     def xla_available() -> bool:

From 01c274cd4dfbd7ed07807aae8273a5261fb51de4 Mon Sep 17 00:00:00 2001
From: Jiasen Wu
Date: Tue, 30 Mar 2021 22:42:39 +0200
Subject: [PATCH 02/12] try

---
 pytorch_lightning/utilities/xla_device.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py
index 493d3f4971235..bc00f557245a7 100644
--- a/pytorch_lightning/utilities/xla_device.py
+++ b/pytorch_lightning/utilities/xla_device.py
@@ -62,7 +62,6 @@ class XLADeviceUtils:
     _TPU_AVAILABLE = False
 
     @staticmethod
-    @pl_multi_process
     def _is_device_tpu() -> bool:
         """
         Check if device is TPU
@@ -106,7 +105,7 @@ def tpu_device_exists() -> bool:
 
         if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE:
 
-            XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu()
+            XLADeviceUtils._TPU_AVAILABLE = bool(pl_multi_process(XLADeviceUtils._is_device_tpu)())
 
         if XLADeviceUtils._TPU_AVAILABLE:
             os.environ["PL_TPU_AVAILABLE"] = '1'

From 87ebd3197f16c99e6dff55f8cbb651608ad94e16 Mon Sep 17 00:00:00 2001
From: Jiasen Wu
Date: Tue, 30 Mar 2021 23:00:53 +0200
Subject: [PATCH 03/12] remove global flag _TPU_AVAILABLE

---
 pytorch_lightning/plugins/training_type/single_tpu.py | 4 ++--
 pytorch_lightning/plugins/training_type/tpu_spawn.py | 4 ++--
 .../trainer/connectors/accelerator_connector.py | 5 +++--
 pytorch_lightning/trainer/training_loop.py | 4 ++--
 pytorch_lightning/utilities/__init__.py | 2 --
 pytorch_lightning/utilities/device_parser.py | 4 ++--
 6 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py
index b8d670ff16881..767b68038626a 100644
--- a/pytorch_lightning/plugins/training_type/single_tpu.py
+++ b/pytorch_lightning/plugins/training_type/single_tpu.py
@@ -19,10 +19,10 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
 from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
-from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities import _XLA_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 
-if _TPU_AVAILABLE:
+if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index a29310f65f724..73a94a1c9f577 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -23,13 +23,13 @@
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn, _OMEGACONF_AVAILABLE
+from pytorch_lightning.utilities import _XLA_AVAILABLE, rank_zero_warn, _OMEGACONF_AVAILABLE
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 
-if _TPU_AVAILABLE:
+if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.parallel_loader as xla_pl
     import torch_xla.distributed.xla_multiprocessing as xmp
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 30d2b48975a84..ec19595710835 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -54,12 +54,12 @@
     _APEX_AVAILABLE,
     _HOROVOD_AVAILABLE,
     _NATIVE_AMP_AVAILABLE,
-    _TPU_AVAILABLE,
     AMPType,
     device_parser,
     DeviceType,
     DistributedType,
     rank_zero_only,
+    XLADeviceUtils,
 )
 from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -555,7 +555,8 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
         rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}')
 
         num_cores = self.tpu_cores if self.tpu_cores is not None else 0
-        rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores')
+        tpu_available = XLADeviceUtils.tpu_device_exists()
+        rank_zero_info(f'TPU available: {tpu_available}, using: {num_cores} TPU cores')
 
         if torch.cuda.is_available() and self._device_type != DeviceType.GPU:
             rank_zero_warn(
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 4640343710f81..4479687f54278 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -25,7 +25,7 @@
 from pytorch_lightning.plugins import ParallelPlugin
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
-from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, parsing
+from pytorch_lightning.utilities import AMPType, DeviceType, parsing, XLADeviceUtils
 from pytorch_lightning.utilities.distributed import rank_zero_info
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import recursive_detach
@@ -408,7 +408,7 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_
             optimizer,
             opt_idx,
             train_step_and_backward_closure,
-            on_tpu=self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE,
+            on_tpu=self.trainer._device_type == DeviceType.TPU and XLADeviceUtils.tpu_device_exists(),
             using_native_amp=using_native_amp,
             using_lbfgs=is_lbfgs,
         )
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index 28cb05bc06f2d..1a9746dc8d11d 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -50,8 +50,6 @@
 from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable  # noqa: F401
 from pytorch_lightning.utilities.xla_device import XLADeviceUtils  # noqa: F401
 
-_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
-
 FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
 FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
 FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps
diff --git a/pytorch_lightning/utilities/device_parser.py b/pytorch_lightning/utilities/device_parser.py
index f20b978ebd8b6..d7e4e5fcaf240 100644
--- a/pytorch_lightning/utilities/device_parser.py
+++ b/pytorch_lightning/utilities/device_parser.py
@@ -15,7 +15,7 @@
 
 import torch
 
-from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities import XLADeviceUtils
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
@@ -109,7 +109,7 @@ def parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[List[int
     if not _tpu_cores_valid(tpu_cores):
         raise MisconfigurationException("`tpu_cores` can only be 1, 8 or [<1-8>]")
 
-    if tpu_cores is not None and not _TPU_AVAILABLE:
+    if tpu_cores is not None and not XLADeviceUtils.tpu_device_exists():
         raise MisconfigurationException('No TPU devices were found.')
 
     return tpu_cores

From 69021cec2695a7eea3d9737b24c235fd6f18c52d Mon Sep 17 00:00:00 2001
From: Jiasen Wu
Date: Tue, 30 Mar 2021 23:07:36 +0200
Subject: [PATCH 04/12] timeout 25 => 60

---
 pytorch_lightning/utilities/xla_device.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py
index bc00f557245a7..f6c3410ff43f2 100644
--- a/pytorch_lightning/utilities/xla_device.py
+++ b/pytorch_lightning/utilities/xla_device.py
@@ -25,7 +25,7 @@
 import torch_xla.distributed.xla_multiprocessing as xmp
 
 #: define waiting time got checking TPU available in sec
-TPU_CHECK_TIMEOUT = 25
+TPU_CHECK_TIMEOUT = 60
 
 
 def inner_f(index, queue, func, *args):  # pragma: no cover

From be31c1f8aa49eb8bff6692da24aec7a22e725a55 Mon Sep 17 00:00:00 2001
From: Jiasen Wu
Date: Tue, 30 Mar 2021 23:12:49 +0200
Subject: [PATCH 05/12] index is required

---
 pytorch_lightning/utilities/xla_device.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py
index f6c3410ff43f2..6b83dae3f83a3 100644
--- a/pytorch_lightning/utilities/xla_device.py
+++ b/pytorch_lightning/utilities/xla_device.py
@@ -62,7 +62,7 @@ class XLADeviceUtils:
     _TPU_AVAILABLE = False
 
     @staticmethod
-    def _is_device_tpu() -> bool:
+    def _is_device_tpu(index) -> bool:
         """
         Check if device is TPU

From 5385fb015e8ba9bbd10d50a176e29f0b3afa9c36 Mon Sep 17 00:00:00 2001
From: Jiasen Wu
Date: Tue, 30 Mar 2021 23:16:40 +0200
Subject: [PATCH 06/12] resolved bad merge

---
 pytorch_lightning/utilities/xla_device.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py
index 6b83dae3f83a3..d95f8d7117b44 100644
--- a/pytorch_lightning/utilities/xla_device.py
+++ b/pytorch_lightning/utilities/xla_device.py
@@ -74,8 +74,7 @@ def _is_device_tpu(index) -> bool:
 
         try:
             device = xm.xla_device()
-            device_type = XLADeviceUtils._fetch_xla_device_type(device)
-            return device_type == "TPU"
+            return device.type == 'xla'
 
         # Missing XLA Configuration
         except RuntimeError as e:

From 9fef5a9881bbaba869f3febd028253e9cbea5fa5 Mon Sep 17 00:00:00 2001
From: Jiasen Wu
Date: Tue, 30 Mar 2021 23:46:05 +0200
Subject: [PATCH 07/12] timeout 60 => 120

---
 pytorch_lightning/utilities/xla_device.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py
index d95f8d7117b44..8e1be4d714f71 100644
--- a/pytorch_lightning/utilities/xla_device.py
+++ b/pytorch_lightning/utilities/xla_device.py
@@ -25,7 +25,7 @@
 import torch_xla.distributed.xla_multiprocessing as xmp
 
 #: define waiting time got checking TPU available in sec
-TPU_CHECK_TIMEOUT = 60
+TPU_CHECK_TIMEOUT = 120
 
 
 def inner_f(index, queue, func, *args):  # pragma: no cover

From b44722bfc587571bcb77ae3cdb7d5aae4400d9c9 Mon Sep 17 00:00:00 2001
From: Jiasen Wu
Date: Wed, 31 Mar 2021 09:56:26 +0200
Subject: [PATCH 08/12] replace _TPU_AVAILABLE in tests

---
 tests/helpers/runif.py | 4 ++--
 tests/models/test_tpu.py | 6 +++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py
index 5483e33d9cddb..d064493051c2e 100644
--- a/tests/helpers/runif.py
+++ b/tests/helpers/runif.py
@@ -29,7 +29,7 @@
     _NATIVE_AMP_AVAILABLE,
     _RPC_AVAILABLE,
     _TORCH_QUANTIZE_AVAILABLE,
-    _TPU_AVAILABLE,
+    XLADeviceUtils,
 )
 
 try:
@@ -132,7 +132,7 @@ def __new__(
             reasons.append("unimplemented on Windows")
 
         if tpu:
-            conditions.append(not _TPU_AVAILABLE)
+            conditions.append(not XLADeviceUtils.tpu_device_exists())
             reasons.append("TPU")
 
         if horovod:
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index b2ed0db87d8d5..5810ec64e66ab 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -25,14 +25,14 @@
 from pytorch_lightning.callbacks import EarlyStopping
 from pytorch_lightning.plugins import TPUSpawnPlugin
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities import _XLA_AVAILABLE, XLADeviceUtils
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
 from tests.helpers.utils import pl_multi_process_test
 
-if _TPU_AVAILABLE:
+if _XLA_AVAILABLE:
     import torch_xla
     import torch_xla.distributed.xla_multiprocessing as xmp
     SERIAL_EXEC = xmp.MpSerialExecutor()
@@ -256,7 +256,7 @@ def test_tpu_misconfiguration():
         Trainer(tpu_cores=[1, 8])
 
 
-@pytest.mark.skipif(_TPU_AVAILABLE, reason="test requires missing TPU")
+@pytest.mark.skipif(XLADeviceUtils.tpu_device_exists(), reason="test requires missing TPU")
 def test_exception_when_no_tpu_found(tmpdir):
     """Test if exception is thrown when xla devices are not available"""

From 9fd8ec2eb465337f46a7d3f8265e8c836789dc5a Mon Sep 17 00:00:00 2001
From: Jiasen Wu
Date: Wed, 31 Mar 2021 11:10:25 +0200
Subject: [PATCH 09/12] remove comment and unused var

---
 pytorch_lightning/utilities/xla_device.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py
index 8e1be4d714f71..6daed6990c386 100644
--- a/pytorch_lightning/utilities/xla_device.py
+++ b/pytorch_lightning/utilities/xla_device.py
@@ -76,8 +76,7 @@ def _is_device_tpu(index) -> bool:
             device = xm.xla_device()
             return device.type == 'xla'
 
-        # Missing XLA Configuration
-        except RuntimeError as e:
+        except RuntimeError:
             traceback.print_exc()
             return False

From f804e6b13acf6a3d348fca36a0ebfaad9d99c6d9 Mon Sep 17 00:00:00 2001
From: Jiasen Wu
Date: Wed, 31 Mar 2021 11:11:07 +0200
Subject: [PATCH 10/12] remove _TPU_AVAILABLE in doc

---
 docs/source/advanced/amp.rst | 2 +-
 docs/source/conf.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/source/advanced/amp.rst b/docs/source/advanced/amp.rst
index d42f1c8c2928d..6fdb34234c08e 100644
--- a/docs/source/advanced/amp.rst
+++ b/docs/source/advanced/amp.rst
@@ -88,7 +88,7 @@ TPU 16-bit
 16-bit on TPUs is much simpler. To use 16-bit with TPUs set precision to 16 when using the TPU flag
 
 .. testcode::
-    :skipif: not _TPU_AVAILABLE
+    :skipif: not XLADeviceUtils.tpu_device_exists()
 
     # DEFAULT
     trainer = Trainer(tpu_cores=8, precision=32)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 1c1f3be8a636a..60f684d01081b 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -385,9 +385,9 @@ def package_list_from_file(file):
     _NATIVE_AMP_AVAILABLE,
     _APEX_AVAILABLE,
     _XLA_AVAILABLE,
-    _TPU_AVAILABLE,
     _TORCHVISION_AVAILABLE,
     _module_available,
+    XLADeviceUtils,
 )
 TORCHVISION_AVAILABLE = _module_available("torchvision")
 """

From 9cb858e4960f157d2ef4b1e544f7c4313fdae97e Mon Sep 17 00:00:00 2001
From: Jiasen Wu
Date: Wed, 31 Mar 2021 16:13:37 +0200
Subject: [PATCH 11/12] revert all changes to _TPU_AVAILABLE flag

---
 docs/source/advanced/amp.rst | 2 +-
 docs/source/conf.py | 2 +-
 pytorch_lightning/plugins/training_type/single_tpu.py | 4 ++--
 pytorch_lightning/plugins/training_type/tpu_spawn.py | 4 ++--
 .../trainer/connectors/accelerator_connector.py | 2 +-
 pytorch_lightning/trainer/training_loop.py | 4 ++--
 pytorch_lightning/utilities/__init__.py | 2 ++
 pytorch_lightning/utilities/device_parser.py | 4 ++--
 tests/helpers/runif.py | 4 ++--
 tests/models/test_tpu.py | 6 +++---
 10 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/docs/source/advanced/amp.rst b/docs/source/advanced/amp.rst
index 6fdb34234c08e..d42f1c8c2928d 100644
--- a/docs/source/advanced/amp.rst
+++ b/docs/source/advanced/amp.rst
@@ -88,7 +88,7 @@ TPU 16-bit
 16-bit on TPUs is much simpler. To use 16-bit with TPUs set precision to 16 when using the TPU flag
 
 .. testcode::
-    :skipif: not XLADeviceUtils.tpu_device_exists()
+    :skipif: not _TPU_AVAILABLE
 
     # DEFAULT
     trainer = Trainer(tpu_cores=8, precision=32)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 60f684d01081b..1c1f3be8a636a 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -385,9 +385,9 @@ def package_list_from_file(file):
     _NATIVE_AMP_AVAILABLE,
     _APEX_AVAILABLE,
     _XLA_AVAILABLE,
+    _TPU_AVAILABLE,
     _TORCHVISION_AVAILABLE,
     _module_available,
-    XLADeviceUtils,
 )
 TORCHVISION_AVAILABLE = _module_available("torchvision")
 """
diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py
index 767b68038626a..b8d670ff16881 100644
--- a/pytorch_lightning/plugins/training_type/single_tpu.py
+++ b/pytorch_lightning/plugins/training_type/single_tpu.py
@@ -19,10 +19,10 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
 from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
-from pytorch_lightning.utilities import _XLA_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 
-if _XLA_AVAILABLE:
+if _TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 73a94a1c9f577..a29310f65f724 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -23,13 +23,13 @@
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _XLA_AVAILABLE, rank_zero_warn, _OMEGACONF_AVAILABLE
+from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn, _OMEGACONF_AVAILABLE
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 
-if _XLA_AVAILABLE:
+if _TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.parallel_loader as xla_pl
     import torch_xla.distributed.xla_multiprocessing as xmp
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index ec19595710835..a1c9d061b6d8a 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -54,12 +54,12 @@
     _APEX_AVAILABLE,
     _HOROVOD_AVAILABLE,
     _NATIVE_AMP_AVAILABLE,
+    _TPU_AVAILABLE,
     AMPType,
     device_parser,
     DeviceType,
     DistributedType,
     rank_zero_only,
-    XLADeviceUtils,
 )
 from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 4479687f54278..4640343710f81 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -25,7 +25,7 @@
 from pytorch_lightning.plugins import ParallelPlugin
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
-from pytorch_lightning.utilities import AMPType, DeviceType, parsing, XLADeviceUtils
+from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, parsing
 from pytorch_lightning.utilities.distributed import rank_zero_info
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import recursive_detach
@@ -408,7 +408,7 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_
             optimizer,
             opt_idx,
             train_step_and_backward_closure,
-            on_tpu=self.trainer._device_type == DeviceType.TPU and XLADeviceUtils.tpu_device_exists(),
+            on_tpu=self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE,
             using_native_amp=using_native_amp,
             using_lbfgs=is_lbfgs,
         )
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index 1a9746dc8d11d..28cb05bc06f2d 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -50,6 +50,8 @@
 from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable  # noqa: F401
 from pytorch_lightning.utilities.xla_device import XLADeviceUtils  # noqa: F401
 
+_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
+
 FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
 FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
 FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps
diff --git a/pytorch_lightning/utilities/device_parser.py b/pytorch_lightning/utilities/device_parser.py
index d7e4e5fcaf240..f20b978ebd8b6 100644
--- a/pytorch_lightning/utilities/device_parser.py
+++ b/pytorch_lightning/utilities/device_parser.py
@@ -15,7 +15,7 @@
 
 import torch
 
-from pytorch_lightning.utilities import XLADeviceUtils
+from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
@@ -109,7 +109,7 @@ def parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[List[int
     if not _tpu_cores_valid(tpu_cores):
         raise MisconfigurationException("`tpu_cores` can only be 1, 8 or [<1-8>]")
 
-    if tpu_cores is not None and not XLADeviceUtils.tpu_device_exists():
+    if tpu_cores is not None and not _TPU_AVAILABLE:
         raise MisconfigurationException('No TPU devices were found.')
 
     return tpu_cores
diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py
index d064493051c2e..5483e33d9cddb 100644
--- a/tests/helpers/runif.py
+++ b/tests/helpers/runif.py
@@ -29,7 +29,7 @@
     _NATIVE_AMP_AVAILABLE,
     _RPC_AVAILABLE,
     _TORCH_QUANTIZE_AVAILABLE,
-    XLADeviceUtils,
+    _TPU_AVAILABLE,
 )
 
 try:
@@ -132,7 +132,7 @@ def __new__(
             reasons.append("unimplemented on Windows")
 
         if tpu:
-            conditions.append(not XLADeviceUtils.tpu_device_exists())
+            conditions.append(not _TPU_AVAILABLE)
             reasons.append("TPU")
 
         if horovod:
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index 5810ec64e66ab..b2ed0db87d8d5 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -25,14 +25,14 @@
 from pytorch_lightning.callbacks import EarlyStopping
 from pytorch_lightning.plugins import TPUSpawnPlugin
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _XLA_AVAILABLE, XLADeviceUtils
+from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
 from tests.helpers.utils import pl_multi_process_test
 
-if _XLA_AVAILABLE:
+if _TPU_AVAILABLE:
     import torch_xla
     import torch_xla.distributed.xla_multiprocessing as xmp
     SERIAL_EXEC = xmp.MpSerialExecutor()
@@ -256,7 +256,7 @@ def test_tpu_misconfiguration():
         Trainer(tpu_cores=[1, 8])
 
 
-@pytest.mark.skipif(XLADeviceUtils.tpu_device_exists(), reason="test requires missing TPU")
+@pytest.mark.skipif(_TPU_AVAILABLE, reason="test requires missing TPU")
 def test_exception_when_no_tpu_found(tmpdir):
     """Test if exception is thrown when xla devices are not available"""

From 2c5a700b9a37283030b4084901bf1381b24c2644 Mon Sep 17 00:00:00 2001
From: Jiasen Wu
Date: Wed, 31 Mar 2021 16:17:24 +0200
Subject: [PATCH 12/12] one more revert

---
 pytorch_lightning/trainer/connectors/accelerator_connector.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index a1c9d061b6d8a..30d2b48975a84 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -555,8 +555,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
         rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}')
 
         num_cores = self.tpu_cores if self.tpu_cores is not None else 0
-        tpu_available = XLADeviceUtils.tpu_device_exists()
-        rank_zero_info(f'TPU available: {tpu_available}, using: {num_cores} TPU cores')
+        rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores')
 
         if torch.cuda.is_available() and self._device_type != DeviceType.GPU:
             rank_zero_warn(