[TPU] Correct the check for TPU device in a pod environment #6755
Changes from 7 commits
First changed file:

@@ -54,12 +54,12 @@
     _APEX_AVAILABLE,
     _HOROVOD_AVAILABLE,
     _NATIVE_AMP_AVAILABLE,
-    _TPU_AVAILABLE,
     AMPType,
     device_parser,
     DeviceType,
     DistributedType,
     rank_zero_only,
+    XLADeviceUtils,
 )
 from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -555,7 +555,8 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):

         rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}')
         num_cores = self.tpu_cores if self.tpu_cores is not None else 0
-        rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores')
+        tpu_available = XLADeviceUtils.tpu_device_exists()
+        rank_zero_info(f'TPU available: {tpu_available}, using: {num_cores} TPU cores')

         if torch.cuda.is_available() and self._device_type != DeviceType.GPU:
             rank_zero_warn(

Review comment on the `tpu_available = XLADeviceUtils.tpu_device_exists()` line: We can use _TPU_AVAILABLE there.
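The net effect of the hunks above is that the connector no longer reads a module-level `_TPU_AVAILABLE` constant but asks `XLADeviceUtils` at call time. A minimal sketch of that call path, assuming the grouped import shown above lives in `pytorch_lightning.utilities` (the `log_device_info` helper is a hypothetical stand-in for the relevant lines of `set_distributed_mode`, not code from this PR):

```python
# Hedged sketch of the new logging path; `log_device_info` is hypothetical.
import torch

from pytorch_lightning.utilities import XLADeviceUtils  # assumed import location
from pytorch_lightning.utilities.distributed import rank_zero_info


def log_device_info(tpu_cores=None):
    rank_zero_info(f'GPU available: {torch.cuda.is_available()}')

    num_cores = tpu_cores if tpu_cores is not None else 0
    # TPU availability is resolved at call time instead of importing the
    # module-level _TPU_AVAILABLE constant.
    tpu_available = XLADeviceUtils.tpu_device_exists()
    rank_zero_info(f'TPU available: {tpu_available}, using: {num_cores} TPU cores')
```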
Second changed file:

@@ -15,7 +15,6 @@
 import os
 import queue as q
 import traceback
-from multiprocessing import Process, Queue

 import torch.multiprocessing as mp

@@ -26,31 +25,33 @@
 import torch_xla.distributed.xla_multiprocessing as xmp

 #: define waiting time got checking TPU available in sec
-TPU_CHECK_TIMEOUT = 25
+TPU_CHECK_TIMEOUT = 120


-def inner_f(queue, func, *args, **kwargs):  # pragma: no cover
-    try:
-        queue.put(func(*args, **kwargs))
-    # todo: specify the possible exception
-    except Exception:
-        traceback.print_exc()
-        queue.put(None)
+def inner_f(index, queue, func, *args):  # pragma: no cover
+    queue.put(func(index, *args))


 def pl_multi_process(func):

     @functools.wraps(func)
-    def wrapper(*args, **kwargs):
-        queue = Queue()
-        proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs)
-        proc.start()
-        proc.join(TPU_CHECK_TIMEOUT)
+    def wrapper(*args):
+        smp = mp.get_context("spawn")
+        queue = smp.Queue()
+        cxt = xmp.spawn(inner_f, args=(queue, func, *args), join=False)
+
+        # errors in the subprocesses are caught and saved in the error_queues
+        # inside the context, but we don't bother to check them.
+        if not cxt.join(TPU_CHECK_TIMEOUT):
+            for proc in cxt.processes:
+                if proc.is_alive():
+                    proc.terminate()
+                    proc.join()

         try:
             return queue.get_nowait()
         except q.Empty:
             traceback.print_exc()
-            return False
+            return None

     return wrapper
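With this rewrite, `inner_f` runs inside a process spawned by `xmp.spawn`, and a probe that does not answer within `TPU_CHECK_TIMEOUT` seconds is terminated and reported as `None`. A hedged sketch of how the decorator is now applied at a call site, mirroring the `tpu_device_exists()` hunk further down; the module path `pytorch_lightning.utilities.xla_device` and an installed `torch_xla` are assumptions:

```python
# Hedged sketch: wrap the probe lazily at the call site, so nothing is
# spawned at import time. Module path and torch_xla availability assumed.
from pytorch_lightning.utilities.xla_device import pl_multi_process, XLADeviceUtils

# _is_device_tpu now takes the spawn index as its first argument, which
# inner_f forwards from xmp.spawn.
tpu_found = pl_multi_process(XLADeviceUtils._is_device_tpu)()

# A timed-out probe leaves the queue empty, so the wrapper returns None,
# which bool() maps to False.
print(bool(tpu_found))
```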
@@ -61,26 +62,24 @@ class XLADeviceUtils:
     _TPU_AVAILABLE = False

     @staticmethod
-    @pl_multi_process
-    def _is_device_tpu() -> bool:
+    def _is_device_tpu(index) -> bool:
         """
         Check if device is TPU

         Return:
             A boolean value indicating if the xla device is a TPU device or not
         """
         if not _XLA_AVAILABLE:
             return False

-        def _fn(_: int, mp_queue):
-            try:
-                device = xm.xla_device()
-                mp_queue.put(device.type == 'xla')
-            except Exception:
-                mp_queue.put(False)
-
-        smp = mp.get_context("spawn")
-        queue = smp.SimpleQueue()
-        xmp.spawn(_fn, args=(queue, ), nprocs=1)
-        return queue.get()
+        try:
+            device = xm.xla_device()
+            return device.type == 'xla'
+        # Missing XLA Configuration
+        except RuntimeError as e:
+            traceback.print_exc()
+            return False

     @staticmethod
     def xla_available() -> bool:

Review comment on the `except RuntimeError as e:` line: Comments can be removed if not used.
Reply: okay
@@ -105,7 +104,7 @@ def tpu_device_exists() -> bool:

         if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE:

-            XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu()
+            XLADeviceUtils._TPU_AVAILABLE = bool(pl_multi_process(XLADeviceUtils._is_device_tpu)())

             if XLADeviceUtils._TPU_AVAILABLE:
                 os.environ["PL_TPU_AVAILABLE"] = '1'
Review thread on the removal of _TPU_AVAILABLE:

Reviewer: Why did you remove _TPU_AVAILABLE?

Author: I replaced the _TPU_AVAILABLE flag with a call to XLADeviceUtils.tpu_device_exists(), because the check now spawns a process to decide the value, and Python complains (a "not frozen yet" spawn error) when that happens at module import time.

Reviewer: Please add back _TPU_AVAILABLE. We use it as part of the common imports API in Lightning, and removing it won't be backward compatible.
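One way to reconcile the two constraints (keep `_TPU_AVAILABLE` importable, but avoid spawning a process at import time) would be a lazily resolved module attribute. The sketch below is not part of this PR; it assumes Python 3.7+ (PEP 562 module-level `__getattr__`), that the constant is re-exported from a utilities `__init__.py`, and the module path used above:

```python
# Hypothetical sketch (not from this PR): keep the old
# `from pytorch_lightning.utilities import _TPU_AVAILABLE` import working
# while deferring the process spawn until the flag is actually read.
from pytorch_lightning.utilities.xla_device import XLADeviceUtils  # assumed path


def __getattr__(name):  # PEP 562: called only for names not defined in the module
    if name == "_TPU_AVAILABLE":
        # The TPU probe (and its spawn) runs on first access, not at import.
        return XLADeviceUtils.tpu_device_exists()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```

Evaluating the flag eagerly at import would also satisfy the import API, but it would reintroduce the spawn-at-import problem the author describes above.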