[TPU] Add support for PJRT from PyTorch/XLA 2.0 #17352

Merged 56 commits on Apr 18, 2023
Changes from 53 commits
Commits (56):
848a559
PJRT related fixes
Liyang90 Mar 20, 2023
50873e7
4 devices on V4
Liyang90 Mar 20, 2023
e00c7a1
minor update
Liyang90 Mar 21, 2023
f07a556
fix for PJRT multithreading
Liyang90 Apr 11, 2023
9c493c6
Merge branch 'master' into xla_pjrt_debug
Liyang90 Apr 11, 2023
f2bc18d
update auto_device_count() for PJRT
Liyang90 Apr 12, 2023
b729add
minor update for PJRT
Liyang90 Apr 12, 2023
51ab4de
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2023
d882417
Apply suggestions from code review
Liyang90 Apr 12, 2023
acb41a8
Avoid warning. Comment. Sort import
carmocca Apr 12, 2023
b3f3e28
Mock experimental
carmocca Apr 13, 2023
9660205
Explicit nprocs
carmocca Apr 13, 2023
918f51c
Merge branch 'master' into xla_pjrt_debug
carmocca Apr 13, 2023
3ea91e5
XLAStrategy broadcast master param
Liyang90 Apr 13, 2023
6eebab0
Different queue with XRT vs PJRT
carmocca Apr 13, 2023
9dc6e33
Update CHANGELOG
carmocca Apr 13, 2023
32e133c
Fix mypy
carmocca Apr 13, 2023
fbd5e46
Port current changes to Fabric
carmocca Apr 14, 2023
bf93f66
Parametrize CPU test that checks for the device offset
carmocca Apr 14, 2023
cda79fe
Manager().Queue() is not the same as mp.Queue()
carmocca Apr 14, 2023
08c2c20
Move outside of FIT check, remove PJRT check
carmocca Apr 14, 2023
659a9c5
deepcopy objects together to preserve the inference relation
Liyang90 Apr 14, 2023
6ecc718
Include function with Fabric too
carmocca Apr 14, 2023
63ede7b
Merge branch 'master' into xla_pjrt_debug
carmocca Apr 14, 2023
58602a6
Uncomment checkgroup skips
carmocca Apr 14, 2023
c3495c3
Merge branch 'master' into xla_pjrt_debug
carmocca Apr 14, 2023
57ed34e
Avoid pickle errors created by local test classes
carmocca Apr 14, 2023
867fc85
Fix various tests
carmocca Apr 14, 2023
ecf3be3
Test to FIXME
carmocca Apr 14, 2023
87eb061
Port changes from #17381
carmocca Apr 14, 2023
0442dc0
Missed this in the commit above
carmocca Apr 14, 2023
745d64d
Merge branch 'master' into xla_pjrt_debug
Borda Apr 14, 2023
eac8a61
Update
carmocca Apr 15, 2023
3dc083b
Update src/lightning/fabric/accelerators/tpu.py
carmocca Apr 16, 2023
343c31b
XFAIL TimeoutError
carmocca Apr 16, 2023
64c8880
Revert "XFAIL TimeoutError"
carmocca Apr 16, 2023
300626b
DEBUG
carmocca Apr 16, 2023
b633619
Merge branch 'master' into xla_pjrt_debug
carmocca Apr 17, 2023
2594088
XLAEnvironment main_address main_port raise NotImplementedError
Liyang90 Apr 18, 2023
1afc27e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 18, 2023
4e64b0d
Fix env var test
carmocca Apr 18, 2023
a18f849
Merge branch 'master' into xla_pjrt_debug
carmocca Apr 18, 2023
e10b702
Still skip XRT
carmocca Apr 18, 2023
79b3113
Revert "DEBUG"
carmocca Apr 18, 2023
b6c32f8
Merge branch 'master' into carmocca/available-check
carmocca Apr 18, 2023
a83d039
Accidental removal
carmocca Apr 18, 2023
7a7a791
pickle issue
carmocca Apr 18, 2023
49ec95c
windows you weird
carmocca Apr 18, 2023
49ea06a
Debug
carmocca Apr 18, 2023
db99898
Fix
carmocca Apr 18, 2023
cfeb486
Merge branch 'master' into xla_pjrt_debug
carmocca Apr 18, 2023
a449a55
Merge branch 'carmocca/available-check' into xla_pjrt_debug
carmocca Apr 18, 2023
e57d5cf
Merge branch 'carmocca/debug-master-failure' into xla_pjrt_debug
carmocca Apr 18, 2023
5c2a72c
Update tests/tests_fabric/plugins/environments/test_xla.py
carmocca Apr 18, 2023
3ebb91e
Merge branch 'master' into xla_pjrt_debug
carmocca Apr 18, 2023
4fe6ea4
Merge branch 'master' into xla_pjrt_debug
carmocca Apr 18, 2023
18 changes: 8 additions & 10 deletions .github/checkgroup.yml
@@ -104,16 +104,15 @@ subprojects:
# checks:
# - "pytorch-lightning (IPUs)"

# TODO: there are issues to address
#- id: "pytorch-lightning: TPU workflow"
# paths:
- id: "pytorch-lightning: TPU workflow"
paths:
# tpu CI availability is very limited, so we only require tpu tests
# to pass when their configurations are modified
# - ".github/workflows/tpu-tests.yml"
# - "tests/tests_pytorch/run_tpu_tests.sh"
# checks:
# - "test-on-tpus (pytorch, xrt)"
# - "test-on-tpus (pytorch, pjrt)"
- ".github/workflows/tpu-tests.yml"
- "tests/tests_pytorch/run_tpu_tests.sh"
checks:
#- "test-on-tpus (pytorch, xrt)"
- "test-on-tpus (pytorch, pjrt)"

- id: "fabric: Docs"
paths:
@@ -238,8 +237,7 @@ subprojects:
- "tests/tests_fabric/run_tpu_tests.sh"
checks:
- "test-on-tpus (fabric, xrt)"
# TODO: uncomment when PJRT support is added
#- "test-on-tpus (pytorch, pjrt)"
- "test-on-tpus (pytorch, pjrt)"

# SECTION: lightning_app

3 changes: 0 additions & 3 deletions .github/workflows/ci-tests-pytorch.yml
@@ -153,9 +153,6 @@ jobs:
- name: Prevent using raw source
run: rm -rf src/

- name: Prevent using raw source
run: rm -rf src/

- name: Testing Warnings
working-directory: tests/tests_pytorch
# needs to run outside of `pytest`
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for the TPU-v4 architecture ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))


- Added support for XLA's new PJRT runtime ([#17352](https://github.com/Lightning-AI/lightning/pull/17352))


- Check for invalid TPU device inputs ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))


80 changes: 42 additions & 38 deletions src/lightning/fabric/accelerators/tpu.py
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import queue as q
import traceback
from multiprocessing import Process, Queue
from typing import Any, Callable, Dict, List, Union

@@ -49,14 +47,21 @@ def parse_devices(devices: Union[int, str, List[int]]) -> Union[int, List[int]]:
@staticmethod
def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]:
"""Gets parallel devices for the Accelerator."""
from torch_xla.experimental import pjrt

devices = _parse_tpu_devices(devices)
# In XLA index 0 maps to CPU, in fact, a `xla_device()` with no arguments has index 1
# since the user passes a 0-based index, we need to adjust the indices
if pjrt.using_pjrt():
device_offset = 0
else:
# In XLA XRT index 0 maps to CPU, in fact, a `xla_device()` with no arguments has index 1
# since the user passes a 0-based index, we need to adjust the indices
device_offset = 1

if isinstance(devices, int):
return [torch.device("xla", i) for i in range(1, devices + 1)]
return [torch.device("xla", i) for i in range(device_offset, devices + device_offset)]
else:
# list of devices is not supported, just a specific index, fine to access [0]
return [torch.device("xla", devices[0] + 1)]
return [torch.device("xla", devices[0] + device_offset)]
# we cannot create `xla_device` here because processes have not been spawned yet (this is called in the
# accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`.
# it will be replaced with `xla_device` (also a torch.device`, but with extra logic) in the strategy
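For illustration, a minimal sketch of the index mapping this offset produces (values are illustrative, not part of the diff): requesting 4 devices yields xla:0..xla:3 under PJRT and xla:1..xla:4 under XRT, where index 0 is the CPU.

    import torch

    # Illustration only: how the offset changes the devices returned for 4 requested devices.
    pjrt_devices = [torch.device("xla", i) for i in range(0, 4)]  # xla:0 .. xla:3
    xrt_devices = [torch.device("xla", i) for i in range(1, 5)]   # xla:1 .. xla:4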
@@ -68,15 +73,33 @@ def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]:
def auto_device_count() -> int:
"""Get the devices when set to auto."""
import torch_xla.core.xla_env_vars as xenv
from torch_xla.experimental import pjrt, tpu
from torch_xla.utils.utils import getenv_as

return getenv_as(xenv.TPU_NUM_DEVICES, int, 8)
if pjrt.using_pjrt():
device_count_on_version = {2: 8, 3: 8, 4: 4}
return device_count_on_version.get(tpu.version(), 8)
else:
return getenv_as(xenv.TPU_NUM_DEVICES, int, 8)

@staticmethod
@functools.lru_cache(maxsize=1)
def is_available() -> bool:
# check `_XLA_AVAILABLE` again to avoid launching processes
return bool(_XLA_AVAILABLE) and _is_device_tpu()
if not _XLA_AVAILABLE:
return False
queue: Queue = Queue()
proc = Process(target=_inner_f, args=(queue, _has_tpu_device))
proc.start()
proc.join(TPU_CHECK_TIMEOUT)
if proc.is_alive():
proc.terminate()
proc.join()
# if the timeout is triggered, fail to avoid silently running on a different accelerator
raise TimeoutError(
"Timed out waiting to check whether a TPU is available. You can increase the TPU_CHECK_TIMEOUT value."
f" Currently {TPU_CHECK_TIMEOUT}"
)
return queue.get_nowait()
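The pattern above (run the probe in a child process, bounded by a timeout) can be sketched in isolation as follows; the probe body here is a stand-in, not the real TPU check:

    from multiprocessing import Process, Queue

    def _probe(queue: Queue) -> None:
        queue.put(True)  # stand-in for the real device check

    if __name__ == "__main__":
        q: Queue = Queue()
        p = Process(target=_probe, args=(q,))
        p.start()
        p.join(60)  # wait at most 60 seconds
        if p.is_alive():
            p.terminate()
            p.join()
            raise TimeoutError("device probe timed out")
        print(q.get_nowait())  # True if the probe finished in time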

@classmethod
def register_accelerators(cls, accelerator_registry: Dict) -> None:
@@ -91,45 +114,26 @@ def register_accelerators(cls, accelerator_registry: Dict) -> None:
TPU_CHECK_TIMEOUT = 60


def _inner_f(queue: Queue, func: Callable, *args: Any, **kwargs: Any) -> None: # pragma: no cover
try:
queue.put(func(*args, **kwargs))
except Exception:
traceback.print_exc()
queue.put(None)


def _multi_process(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Union[bool, Any]:
queue: Queue = Queue()
proc = Process(target=_inner_f, args=(queue, func, *args), kwargs=kwargs)
proc.start()
proc.join(TPU_CHECK_TIMEOUT)
try:
return queue.get_nowait()
except q.Empty:
traceback.print_exc()
return False

return wrapper
def _inner_f(queue: Queue, func: Callable) -> None:
res = func()
queue.put(res)


@_multi_process
def _is_device_tpu() -> bool:
"""Check if TPU devices are available. Runs XLA device check within a separate process.
def _has_tpu_device() -> bool:
"""Check if TPU devices are available.

Return:
A boolean value indicating if TPU devices are available
"""
if not _XLA_AVAILABLE:
return False
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt

# For the TPU Pod training process, for example, if we have
# TPU v3-32 with 4 VMs, the world size would be 4 and as
# we would have to use `torch_xla.distributed.xla_dist` for
# multiple VMs and TPU_CONFIG won't be available, running
if pjrt.using_pjrt():
return bool(xm.get_xla_supported_devices("TPU"))
# For the TPU Pod training process, for example, if we have TPU v3-32 with 4 VMs, the world size would be 4 and as
# we would have to use `torch_xla.distributed.xla_dist` for multiple VMs and TPU_CONFIG won't be available, running
# `xm.get_xla_supported_devices("TPU")` won't be possible.
return (xm.xrt_world_size() > 1) or bool(xm.get_xla_supported_devices("TPU"))

11 changes: 4 additions & 7 deletions src/lightning/fabric/plugins/environments/xla.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from typing import Any

from lightning.fabric.accelerators.tpu import _XLA_AVAILABLE, TPUAccelerator
@@ -39,15 +38,13 @@ def creates_processes_externally(self) -> bool:

@property
def main_address(self) -> str:
import torch_xla.core.xla_env_vars as xenv

return os.environ[xenv.TPU_MESH_CTLER_ADDR]
# unused by lightning
raise NotImplementedError

@property
def main_port(self) -> int:
import torch_xla.core.xla_env_vars as xenv

return int(os.environ[xenv.TPU_MESH_CTLER_PORT])
# unused by lightning
raise NotImplementedError

@staticmethod
def detect() -> bool:
39 changes: 32 additions & 7 deletions src/lightning/fabric/strategies/launchers/xla.py
@@ -11,11 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
import time
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, Optional, TYPE_CHECKING
from typing import Any, Callable, Optional, TYPE_CHECKING, Union

from torch.multiprocessing import get_context
import torch.multiprocessing as mp

from lightning.fabric.accelerators.tpu import _XLA_AVAILABLE
from lightning.fabric.strategies.launchers.launcher import _Launcher
@@ -63,15 +63,30 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
*args: Optional positional arguments to be passed to the given function.
**kwargs: Optional keyword arguments to be passed to the given function.
"""
context = get_context(self._start_method)
return_queue = context.SimpleQueue()
from torch_xla.experimental import pjrt

using_pjrt = pjrt.using_pjrt()
return_queue: Union[queue.Queue, mp.SimpleQueue]
if using_pjrt:
# pjrt requires that the queue is serializable
return_queue = mp.Manager().Queue()
else:
return_queue = mp.get_context(self._start_method).SimpleQueue()

import torch_xla.distributed.xla_multiprocessing as xmp

spawn_kwargs = {}
nprocs = self._strategy.num_processes
if not using_pjrt or nprocs == 1:
# avoid warning: "Unsupported nprocs". If it's 1, it will call the launched function directly.
# otherwise it will use all devices
spawn_kwargs["nprocs"] = nprocs

xmp.spawn(
self._wrapping_function,
args=(function, args, kwargs, return_queue),
nprocs=self._strategy.num_processes,
start_method=self._start_method,
**spawn_kwargs,
)
return return_queue.get()

@@ -83,9 +98,19 @@ def _wrapping_function(
function: Callable,
args: Any,
kwargs: Any,
return_queue: SimpleQueue,
return_queue: Union[mp.SimpleQueue, queue.Queue],
global_states: Optional[_GlobalStateSnapshot] = None,
) -> None:
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt

if pjrt.using_pjrt() and len(xm.get_xla_supported_devices()) > 1:
# `get_xla_supported_devices` in the spawned process returns the logical devices (2 for v2/v3 and 1 for v4)
# so when there's more than one (multithreading), objects need to be deep-copied
import copy

function, args, kwargs = copy.deepcopy((function, args, kwargs))

results = function(*args, **kwargs)

if self._strategy.local_rank == 0:
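The objects are deep-copied together (as one tuple) rather than one at a time because `copy.deepcopy` preserves shared references only within a single call; a small standalone sketch:

    import copy

    model = ["weights"]          # stand-in for an object shared between args and kwargs
    args = (model,)
    kwargs = {"module": model}   # references the same object as args[0]

    # Copied together: the shared reference survives inside the copy.
    args_c, kwargs_c = copy.deepcopy((args, kwargs))
    assert args_c[0] is kwargs_c["module"]

    # Copied separately: the relation is lost; each copy gets its own object.
    args_s, kwargs_s = copy.deepcopy(args), copy.deepcopy(kwargs)
    assert args_s[0] is not kwargs_s["module"]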
3 changes: 3 additions & 0 deletions src/lightning/fabric/strategies/xla.py
@@ -89,6 +89,9 @@ def setup_environment(self) -> None:
super().setup_environment()

def setup_module(self, module: Module) -> Module:
from torch_xla.experimental import pjrt

pjrt.broadcast_master_param(module)
return module

def module_to_device(self, module: Module) -> None:
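Under PJRT each spawned worker initializes its own copy of the module, so rank 0's parameters are broadcast to keep the replicas identical before training. A hedged sketch of the same call in user code (requires torch_xla 2.0 on a TPU host; `_mp_fn` is an illustrative name):

    import torch.nn as nn
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
    from torch_xla.experimental import pjrt

    def _mp_fn(index: int) -> None:
        model = nn.Linear(4, 4).to(xm.xla_device())  # each process gets fresh random weights
        pjrt.broadcast_master_param(model)           # overwrite them with rank 0's parameters

    if __name__ == "__main__":
        xmp.spawn(_mp_fn)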
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -17,6 +17,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for the TPU-v4 architecture ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))

-
- Added support for XLA's new PJRT runtime ([#17352](https://github.com/Lightning-AI/lightning/pull/17352))


- Check for invalid TPU device inputs ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))

@@ -13,11 +13,11 @@
# limitations under the License.
import logging
import os
import queue
import tempfile
from contextlib import suppress
from dataclasses import dataclass
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional
from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional, Union

import numpy as np
import torch
@@ -138,7 +138,7 @@ def _wrapping_function(
function: Callable,
args: Any,
kwargs: Any,
return_queue: SimpleQueue,
return_queue: Union[mp.SimpleQueue, queue.Queue],
global_states: Optional["_GlobalStateSnapshot"] = None,
) -> None:
if global_states:
37 changes: 31 additions & 6 deletions src/lightning/pytorch/strategies/launchers/xla.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, Optional
import queue
from typing import Any, Callable, Optional, Union

import torch.multiprocessing as mp

@@ -68,16 +68,31 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
a selected set of attributes get restored in the main process after processes join.
**kwargs: Optional keyword arguments to be passed to the given function.
"""
context = mp.get_context(self._start_method)
return_queue = context.SimpleQueue()
from torch_xla.experimental import pjrt

using_pjrt = pjrt.using_pjrt()
return_queue: Union[queue.Queue, mp.SimpleQueue]
if using_pjrt:
# pjrt requires that the queue is serializable
return_queue = mp.Manager().Queue()
else:
return_queue = mp.get_context(self._start_method).SimpleQueue()

import torch_xla.distributed.xla_multiprocessing as xmp

spawn_kwargs = {}
nprocs = self._strategy.num_processes
if not using_pjrt or nprocs == 1:
# avoid warning: "Unsupported nprocs". If it's 1, it will call the launched function directly.
# otherwise it will use all devices
spawn_kwargs["nprocs"] = nprocs

process_context = xmp.spawn(
self._wrapping_function,
args=(trainer, function, args, kwargs, return_queue),
nprocs=self._strategy.num_processes,
start_method=self._start_method,
join=False, # we will join ourselves to get the process references
**spawn_kwargs,
)
# xla will not actually create processes if only 1 device
if process_context is not None:
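The `nprocs` handling above can be sketched as follows (requires torch_xla; illustrative only): under PJRT the argument is omitted so all local devices are used, while `nprocs=1` runs the function directly in the current process, which is handy for debugging.

    import torch_xla.distributed.xla_multiprocessing as xmp

    def _mp_fn(index: int) -> None:
        print(f"running replica {index}")

    if __name__ == "__main__":
        xmp.spawn(_mp_fn)              # PJRT: use all local devices, avoids the "Unsupported nprocs" warning
        # xmp.spawn(_mp_fn, nprocs=1)  # run directly in the current process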
@@ -101,9 +116,19 @@ def _wrapping_function(
function: Callable,
args: Any,
kwargs: Any,
return_queue: SimpleQueue,
return_queue: Union[mp.SimpleQueue, queue.Queue],
global_states: Optional[_GlobalStateSnapshot] = None,
) -> None:
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt

if pjrt.using_pjrt() and len(xm.get_xla_supported_devices()) > 1:
# `get_xla_supported_devices` in the spawned process returns the logical devices (2 for v2/v3 and 1 for v4)
# so when there's more than one (multithreading), objects need to be deep-copied
import copy

trainer, function, args, kwargs = copy.deepcopy((trainer, function, args, kwargs))

results = function(*args, **kwargs)

if trainer is not None: