diff --git a/.github/checkgroup.yml b/.github/checkgroup.yml
index f9c2bade644c3..d39e8e2f23526 100644
--- a/.github/checkgroup.yml
+++ b/.github/checkgroup.yml
@@ -104,16 +104,15 @@ subprojects:
   #  checks:
   #    - "pytorch-lightning (IPUs)"
 
-  # TODO: there are issues to address
-  #- id: "pytorch-lightning: TPU workflow"
-  #  paths:
+  - id: "pytorch-lightning: TPU workflow"
+    paths:
       # tpu CI availability is very limited, so we only require tpu tests
       # to pass when their configurations are modified
-  #    - ".github/workflows/tpu-tests.yml"
-  #    - "tests/tests_pytorch/run_tpu_tests.sh"
-  #  checks:
-  #    - "test-on-tpus (pytorch, xrt)"
-  #    - "test-on-tpus (pytorch, pjrt)"
+      - ".github/workflows/tpu-tests.yml"
+      - "tests/tests_pytorch/run_tpu_tests.sh"
+    checks:
+      #- "test-on-tpus (pytorch, xrt)"
+      - "test-on-tpus (pytorch, pjrt)"
 
   - id: "fabric: Docs"
     paths:
@@ -238,8 +237,7 @@ subprojects:
       - "tests/tests_fabric/run_tpu_tests.sh"
     checks:
       - "test-on-tpus (fabric, xrt)"
-      # TODO: uncomment when PJRT support is added
-      #- "test-on-tpus (pytorch, pjrt)"
+      - "test-on-tpus (fabric, pjrt)"
 
   # SECTION: lightning_app
diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 759a8bad38db7..50d4d7e51eb2b 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for the TPU-v4 architecture ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))
 
+- Added support for XLA's new PJRT runtime ([#17352](https://github.com/Lightning-AI/lightning/pull/17352))
+
+
 - Check for invalid TPU device inputs ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))
diff --git a/src/lightning/fabric/accelerators/tpu.py b/src/lightning/fabric/accelerators/tpu.py
index 2593fc1c80ec3..e37ed69d42f14 100644
--- a/src/lightning/fabric/accelerators/tpu.py
+++ b/src/lightning/fabric/accelerators/tpu.py
@@ -47,14 +47,21 @@ def parse_devices(devices: Union[int, str, List[int]]) -> Union[int, List[int]]:
     @staticmethod
     def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]:
         """Gets parallel devices for the Accelerator."""
+        from torch_xla.experimental import pjrt
+
         devices = _parse_tpu_devices(devices)
-        # In XLA index 0 maps to CPU, in fact, a `xla_device()` with no arguments has index 1
-        # since the user passes a 0-based index, we need to adjust the indices
+        if pjrt.using_pjrt():
+            device_offset = 0
+        else:
+            # In XLA's XRT runtime, index 0 maps to the CPU (an `xla_device()` with no arguments gets index 1).
+            # Since the user passes a 0-based index, we need to shift the indices by one.
+            device_offset = 1
+
         if isinstance(devices, int):
-            return [torch.device("xla", i) for i in range(1, devices + 1)]
+            return [torch.device("xla", i) for i in range(device_offset, devices + device_offset)]
         else:
             # list of devices is not supported, just a specific index, fine to access [0]
-            return [torch.device("xla", devices[0] + 1)]
+            return [torch.device("xla", devices[0] + device_offset)]
         # we cannot create `xla_device` here because processes have not been spawned yet (this is called in the
         # accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`.
         # it will be replaced with `xla_device` (also a `torch.device`, but with extra logic) in the strategy
@@ -66,9 +73,14 @@ def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]:
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
         import torch_xla.core.xla_env_vars as xenv
+        from torch_xla.experimental import pjrt, tpu
        from torch_xla.utils.utils import getenv_as
 
-        return getenv_as(xenv.TPU_NUM_DEVICES, int, 8)
+        if pjrt.using_pjrt():
+            device_count_on_version = {2: 8, 3: 8, 4: 4}
+            return device_count_on_version.get(tpu.version(), 8)
+        else:
+            return getenv_as(xenv.TPU_NUM_DEVICES, int, 8)
 
     @staticmethod
     @functools.lru_cache(maxsize=1)
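For reference, a minimal sketch of the indexing rule this hunk introduces (illustrative only, not part of the diff; `torch.device` is replaced with plain strings so it runs without torch_xla installed):

```python
from typing import List


def parallel_device_indices(devices: int, using_pjrt: bool) -> List[str]:
    # mirrors `TPUAccelerator.get_parallel_devices`: PJRT device indices are
    # 0-based, while under XRT index 0 is the CPU so user indices shift by one
    device_offset = 0 if using_pjrt else 1
    return [f"xla:{i}" for i in range(device_offset, devices + device_offset)]


assert parallel_device_indices(4, using_pjrt=True) == ["xla:0", "xla:1", "xla:2", "xla:3"]
assert parallel_device_indices(4, using_pjrt=False) == ["xla:1", "xla:2", "xla:3", "xla:4"]
```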
diff --git a/src/lightning/fabric/plugins/environments/xla.py b/src/lightning/fabric/plugins/environments/xla.py
index 658ff19d26c1d..eeb7784b52194 100644
--- a/src/lightning/fabric/plugins/environments/xla.py
+++ b/src/lightning/fabric/plugins/environments/xla.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import os
 from typing import Any
 
 from lightning.fabric.accelerators.tpu import _XLA_AVAILABLE, TPUAccelerator
@@ -39,15 +38,13 @@ def creates_processes_externally(self) -> bool:
 
     @property
     def main_address(self) -> str:
-        import torch_xla.core.xla_env_vars as xenv
-
-        return os.environ[xenv.TPU_MESH_CTLER_ADDR]
+        # unused by Lightning
+        raise NotImplementedError
 
     @property
     def main_port(self) -> int:
-        import torch_xla.core.xla_env_vars as xenv
-
-        return int(os.environ[xenv.TPU_MESH_CTLER_PORT])
+        # unused by Lightning
+        raise NotImplementedError
 
     @staticmethod
     def detect() -> bool:
diff --git a/src/lightning/fabric/strategies/launchers/xla.py b/src/lightning/fabric/strategies/launchers/xla.py
index 6911c4014ccd8..38667e4b445c8 100644
--- a/src/lightning/fabric/strategies/launchers/xla.py
+++ b/src/lightning/fabric/strategies/launchers/xla.py
@@ -11,11 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import queue
 import time
-from multiprocessing.queues import SimpleQueue
-from typing import Any, Callable, Optional, TYPE_CHECKING
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 
-from torch.multiprocessing import get_context
+import torch.multiprocessing as mp
 
 from lightning.fabric.accelerators.tpu import _XLA_AVAILABLE
 from lightning.fabric.strategies.launchers.launcher import _Launcher
@@ -63,15 +63,30 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
             *args: Optional positional arguments to be passed to the given function.
             **kwargs: Optional keyword arguments to be passed to the given function.
         """
-        context = get_context(self._start_method)
-        return_queue = context.SimpleQueue()
+        from torch_xla.experimental import pjrt
+
+        using_pjrt = pjrt.using_pjrt()
+        return_queue: Union[queue.Queue, mp.SimpleQueue]
+        if using_pjrt:
+            # PJRT requires that the queue is serializable
+            return_queue = mp.Manager().Queue()
+        else:
+            return_queue = mp.get_context(self._start_method).SimpleQueue()
+
         import torch_xla.distributed.xla_multiprocessing as xmp
 
+        spawn_kwargs = {}
+        nprocs = self._strategy.num_processes
+        if not using_pjrt or nprocs == 1:
+            # avoid the warning "Unsupported nprocs": if it's 1, `xmp.spawn` will call the launched function directly,
+            # otherwise (`nprocs` unset under PJRT) it will use all available devices
+            spawn_kwargs["nprocs"] = nprocs
+
         xmp.spawn(
             self._wrapping_function,
             args=(function, args, kwargs, return_queue),
-            nprocs=self._strategy.num_processes,
             start_method=self._start_method,
+            **spawn_kwargs,
         )
         return return_queue.get()
 
@@ -83,9 +98,19 @@ def _wrapping_function(
         function: Callable,
         args: Any,
         kwargs: Any,
-        return_queue: SimpleQueue,
+        return_queue: Union[mp.SimpleQueue, queue.Queue],
         global_states: Optional[_GlobalStateSnapshot] = None,
     ) -> None:
+        import torch_xla.core.xla_model as xm
+        from torch_xla.experimental import pjrt
+
+        if pjrt.using_pjrt() and len(xm.get_xla_supported_devices()) > 1:
+            # `get_xla_supported_devices` in the spawned process returns the logical devices (2 for v2/v3, 1 for v4),
+            # so when there's more than one (multithreading), objects need to be deep-copied
+            import copy
+
+            function, args, kwargs = copy.deepcopy((function, args, kwargs))
+
         results = function(*args, **kwargs)
 
         if self._strategy.local_rank == 0:
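Why the Manager queue: under PJRT, `xmp.spawn` serializes the arguments it hands to each replica, and a raw `multiprocessing` `SimpleQueue` refuses to be pickled outside of process inheritance, whereas a `Manager().Queue()` is a picklable proxy to a queue living in a manager server process. A standard-library sketch of the difference:

```python
import multiprocessing as mp
import pickle

if __name__ == "__main__":
    simple = mp.get_context("spawn").SimpleQueue()
    try:
        pickle.dumps(simple)
    except RuntimeError as e:
        # "... objects should only be shared between processes through inheritance"
        print(f"SimpleQueue is not serializable: {e}")

    managed = mp.Manager().Queue()
    pickle.dumps(managed)  # fine: the proxy serializes to a token for the manager process
```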
diff --git a/src/lightning/fabric/strategies/xla.py b/src/lightning/fabric/strategies/xla.py
index 4b3d88d7b17c8..26499280431ed 100644
--- a/src/lightning/fabric/strategies/xla.py
+++ b/src/lightning/fabric/strategies/xla.py
@@ -89,6 +89,9 @@ def setup_environment(self) -> None:
         super().setup_environment()
 
     def setup_module(self, module: Module) -> Module:
+        from torch_xla.experimental import pjrt
+
+        pjrt.broadcast_master_param(module)
         return module
 
     def module_to_device(self, module: Module) -> None:
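`broadcast_master_param` matters under PJRT because each replica process initializes the module independently, so the replicas would otherwise start from different random weights. Conceptually it does the following, shown here with plain `torch.distributed` collectives as a hedged illustration only (the real call is `torch_xla.experimental.pjrt.broadcast_master_param`):

```python
import torch.distributed as dist
from torch.nn import Module


def broadcast_master_param_sketch(module: Module) -> None:
    # assumes `dist.init_process_group` has already been called;
    # every rank overwrites its randomly initialized parameters with
    # rank 0's values so that all replicas start in sync
    for param in module.parameters():
        dist.broadcast(param.data, src=0)
```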
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index a407d5b563ca1..e80a8b31c4ae8 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -17,6 +17,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for the TPU-v4 architecture ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))
 
+- Added support for XLA's new PJRT runtime ([#17352](https://github.com/Lightning-AI/lightning/pull/17352))
+
+
 - Check for invalid TPU device inputs ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))
diff --git a/src/lightning/pytorch/strategies/launchers/multiprocessing.py b/src/lightning/pytorch/strategies/launchers/multiprocessing.py
index 60a0f7195f7fc..bb8e5a382ad70 100644
--- a/src/lightning/pytorch/strategies/launchers/multiprocessing.py
+++ b/src/lightning/pytorch/strategies/launchers/multiprocessing.py
@@ -13,11 +13,11 @@
 # limitations under the License.
 import logging
 import os
+import queue
 import tempfile
 from contextlib import suppress
 from dataclasses import dataclass
-from multiprocessing.queues import SimpleQueue
-from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional
+from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional, Union
 
 import numpy as np
 import torch
@@ -138,7 +138,7 @@ def _wrapping_function(
         function: Callable,
         args: Any,
         kwargs: Any,
-        return_queue: SimpleQueue,
+        return_queue: Union[mp.SimpleQueue, queue.Queue],
         global_states: Optional["_GlobalStateSnapshot"] = None,
     ) -> None:
         if global_states:
diff --git a/src/lightning/pytorch/strategies/launchers/xla.py b/src/lightning/pytorch/strategies/launchers/xla.py
index 823e5671c3b3e..af021abcb2c2b 100644
--- a/src/lightning/pytorch/strategies/launchers/xla.py
+++ b/src/lightning/pytorch/strategies/launchers/xla.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from multiprocessing.queues import SimpleQueue
-from typing import Any, Callable, Optional
+import queue
+from typing import Any, Callable, Optional, Union
 
 import torch.multiprocessing as mp
 
@@ -68,16 +68,31 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
             a selected set of attributes get restored in the main process after processes join.
             **kwargs: Optional keyword arguments to be passed to the given function.
         """
-        context = mp.get_context(self._start_method)
-        return_queue = context.SimpleQueue()
+        from torch_xla.experimental import pjrt
+
+        using_pjrt = pjrt.using_pjrt()
+        return_queue: Union[queue.Queue, mp.SimpleQueue]
+        if using_pjrt:
+            # PJRT requires that the queue is serializable
+            return_queue = mp.Manager().Queue()
+        else:
+            return_queue = mp.get_context(self._start_method).SimpleQueue()
+
         import torch_xla.distributed.xla_multiprocessing as xmp
 
+        spawn_kwargs = {}
+        nprocs = self._strategy.num_processes
+        if not using_pjrt or nprocs == 1:
+            # avoid the warning "Unsupported nprocs": if it's 1, `xmp.spawn` will call the launched function directly,
+            # otherwise (`nprocs` unset under PJRT) it will use all available devices
+            spawn_kwargs["nprocs"] = nprocs
+
         process_context = xmp.spawn(
             self._wrapping_function,
             args=(trainer, function, args, kwargs, return_queue),
-            nprocs=self._strategy.num_processes,
             start_method=self._start_method,
             join=False,  # we will join ourselves to get the process references
+            **spawn_kwargs,
         )
         # xla will not actually create processes if only 1 device
         if process_context is not None:
@@ -101,9 +116,19 @@ def _wrapping_function(
         function: Callable,
         args: Any,
         kwargs: Any,
-        return_queue: SimpleQueue,
+        return_queue: Union[mp.SimpleQueue, queue.Queue],
         global_states: Optional[_GlobalStateSnapshot] = None,
     ) -> None:
+        import torch_xla.core.xla_model as xm
+        from torch_xla.experimental import pjrt
+
+        if pjrt.using_pjrt() and len(xm.get_xla_supported_devices()) > 1:
+            # `get_xla_supported_devices` in the spawned process returns the logical devices (2 for v2/v3, 1 for v4),
+            # so when there's more than one (multithreading), objects need to be deep-copied
+            import copy
+
+            trainer, function, args, kwargs = copy.deepcopy((trainer, function, args, kwargs))
+
         results = function(*args, **kwargs)
 
         if trainer is not None:
diff --git a/src/lightning/pytorch/strategies/xla.py b/src/lightning/pytorch/strategies/xla.py
index fb3eb92458774..6654b6b015881 100644
--- a/src/lightning/pytorch/strategies/xla.py
+++ b/src/lightning/pytorch/strategies/xla.py
@@ -92,14 +92,9 @@ def root_device(self) -> torch.device:
 
         return xm.xla_device()
 
-    @property
-    def local_rank(self) -> int:
-        return self.cluster_environment.local_rank() if self.cluster_environment is not None else 0
-
     def connect(self, model: "pl.LightningModule") -> None:
-        import torch_xla.distributed.xla_multiprocessing as xmp
-
-        self.wrapped_model = xmp.MpModelWrapper(_LightningModuleWrapperBase(model))
+        # this is called in the spawned process, so there is no need to use `xmp.MpModelWrapper`
+        self.wrapped_model = _LightningModuleWrapperBase(model)
         return super().connect(model)
 
     def _configure_launcher(self) -> None:
@@ -119,6 +114,10 @@ def setup(self, trainer: "pl.Trainer") -> None:
             set_shared_parameters(self.lightning_module, shared_params)
         self.setup_precision_plugin()
 
+        from torch_xla.experimental import pjrt
+
+        pjrt.broadcast_master_param(self.model)
+
         if trainer.state.fn == TrainerFn.FITTING:
             self.setup_optimizers(trainer)
             _optimizers_to_device(self.optimizers, self.root_device)
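Stepping back to the `_wrapping_function` changes in both launchers above: PJRT on TPU v2/v3 runs two replica threads per host process, so without the `copy.deepcopy` all threads would share the same Python objects. A small standard-library illustration of the aliasing being avoided:

```python
import copy
import threading

shared = {"weights": [0.0]}


def replica(args: dict, use_copy: bool, seen: list) -> None:
    # with use_copy=False both threads receive the very same dict,
    # which is the hazard the launcher guards against
    local = copy.deepcopy(args) if use_copy else args
    seen.append(id(local["weights"]))


for use_copy in (False, True):
    seen: list = []
    threads = [threading.Thread(target=replica, args=(shared, use_copy, seen)) for _ in range(2)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(f"use_copy={use_copy}: {len(set(seen))} distinct object(s)")  # 1, then 2
```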
diff --git a/tests/tests_fabric/accelerators/test_tpu.py b/tests/tests_fabric/accelerators/test_tpu.py
index c0ab04a82f269..bee26df79bf68 100644
--- a/tests/tests_fabric/accelerators/test_tpu.py
+++ b/tests/tests_fabric/accelerators/test_tpu.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License
-import os
 
 import pytest
 
@@ -21,7 +20,9 @@
 
 @RunIf(tpu=True)
 def test_auto_device_count():
-    assert TPUAccelerator.auto_device_count() == int(os.environ["TPU_NUM_DEVICES"])
+    # the result depends on the chip generation, e.g. with a v4-8 we expect 4; there's no easy way to test this
+    # without duplicating the `auto_device_count` logic, so just check that it's greater than 1
+    assert TPUAccelerator.auto_device_count() > 1
 
 
 @RunIf(tpu=True)
diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py
index 51fe6961de2eb..3256ed9e0d62b 100644
--- a/tests/tests_fabric/conftest.py
+++ b/tests/tests_fabric/conftest.py
@@ -104,6 +104,7 @@ def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N
     monkeypatch.setattr(lightning.fabric.accelerators.tpu.TPUAccelerator, "auto_device_count", lambda *_: 8)
     monkeypatch.setitem(sys.modules, "torch_xla", Mock())
     monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock())
+    monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock())
 
 
 @pytest.fixture(scope="function")
diff --git a/tests/tests_fabric/plugins/environments/test_xla.py b/tests/tests_fabric/plugins/environments/test_xla.py
index b94efed1579ef..3c80498031977 100644
--- a/tests/tests_fabric/plugins/environments/test_xla.py
+++ b/tests/tests_fabric/plugins/environments/test_xla.py
@@ -35,33 +35,38 @@ def test_default_attributes(*_):
     assert env.local_rank() == 0
     assert env.node_rank() == 0
 
-    with pytest.raises(KeyError):
-        # main_address is required to be passed as env variable
+    with pytest.raises(NotImplementedError):
         _ = env.main_address
-    with pytest.raises(KeyError):
-        # main_port is required to be passed as env variable
+    with pytest.raises(NotImplementedError):
         _ = env.main_port
 
 
 @RunIf(tpu=True)
-@mock.patch.dict(
-    os.environ,
-    {
-        **os.environ,
-        "TPU_MESH_CONTROLLER_ADDRESS": "1.2.3.4",
-        "TPU_MESH_CONTROLLER_PORT": "500",
-        "XRT_SHARD_WORLD_SIZE": "1",
-        "XRT_SHARD_ORDINAL": "0",
-        "XRT_SHARD_LOCAL_ORDINAL": "2",
-        "XRT_HOST_ORDINAL": "3",
-    },
-    clear=True,
-)
-def test_attributes_from_environment_variables():
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
+def test_attributes_from_environment_variables(monkeypatch):
     """Test that the default cluster environment takes the attributes from the environment variables."""
+    from torch_xla.experimental import pjrt
+
+    os.environ["XRT_HOST_ORDINAL"] = "3"
+    if not pjrt.using_pjrt():
+        os.environ.update(
+            {
+                "XRT_SHARD_WORLD_SIZE": "1",
+                "XRT_SHARD_ORDINAL": "0",
+                "XRT_SHARD_LOCAL_ORDINAL": "2",
+            }
+        )
+    else:
+        # PJRT doesn't pull these from environment variables
+        monkeypatch.setattr(pjrt, "world_size", lambda: 1)
+        monkeypatch.setattr(pjrt, "global_ordinal", lambda: 0)
+        monkeypatch.setattr(pjrt, "local_ordinal", lambda: 2)
+
     env = XLAEnvironment()
-    assert env.main_address == "1.2.3.4"
-    assert env.main_port == 500
+    with pytest.raises(NotImplementedError):
+        _ = env.main_address
+    with pytest.raises(NotImplementedError):
+        _ = env.main_port
     assert env.world_size() == 1
     assert env.global_rank() == 0
     assert env.local_rank() == 2
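The test above encodes where each runtime sources its topology: XRT publishes it through `XRT_SHARD_*` environment variables, while PJRT exposes it through runtime calls such as `pjrt.world_size()`. A hedged sketch of that dispatch (the real logic lives in `XLAEnvironment` and `torch_xla`; `xla_world_size` is a hypothetical helper for illustration):

```python
import os


def xla_world_size(using_pjrt: bool) -> int:
    if using_pjrt:
        # PJRT derives the topology from the runtime, not from environment variables
        from torch_xla.experimental import pjrt

        return pjrt.world_size()
    return int(os.environ.get("XRT_SHARD_WORLD_SIZE", 1))
```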
diff --git a/tests/tests_pytorch/accelerators/test_tpu.py b/tests/tests_pytorch/accelerators/test_tpu.py
index 98555e8d34a3a..956b51ce6b3e7 100644
--- a/tests/tests_pytorch/accelerators/test_tpu.py
+++ b/tests/tests_pytorch/accelerators/test_tpu.py
@@ -99,60 +99,61 @@ def test_accelerator_tpu(accelerator, devices, tpu_available):
     assert isinstance(trainer.strategy, XLAStrategy)
 
 
+class ManualOptimizationModel(BoringModel):
+
+    count = 0
+    called = collections.defaultdict(int)
+
+    def __init__(self):
+        super().__init__()
+        self.automatic_optimization = False
+
+    @property
+    def should_update(self):
+        return self.count % 2 == 0
+
+    def on_train_batch_start(self, batch, batch_idx):
+        self.called["on_train_batch_start"] += 1
+        self.weight_before = self.layer.weight.clone()
+
+    def training_step(self, batch, batch_idx):
+        self.called["training_step"] += 1
+        opt = self.optimizers()
+        loss = self.step(batch)
+
+        if self.should_update:
+            self.manual_backward(loss)
+            opt.step()
+            opt.zero_grad()
+        return loss
+
+    def on_train_batch_end(self, *_):
+        self.called["on_train_batch_end"] += 1
+        after_before = self.layer.weight.clone()
+        if self.should_update:
+            assert not torch.equal(self.weight_before, after_before), self.count
+        else:
+            assert torch.equal(self.weight_before, after_before)
+            assert_emtpy_grad(self.layer.weight.grad)
+        self.count += 1
+
+    def on_train_start(self):
+        opt = self.optimizers()
+        self.opt_step_patch = patch.object(opt, "step", wraps=opt.step)
+        self.opt_step_mock = self.opt_step_patch.start()
+
+    def on_train_end(self):
+        assert self.called["training_step"] == 5
+        assert self.called["on_train_batch_start"] == 5
+        assert self.called["on_train_batch_end"] == 5
+
+        self.opt_step_patch.stop()
+        assert self.opt_step_mock.call_count == 3
+
+
 @RunIf(tpu=True)
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_manual_optimization_tpus(tmpdir):
-    class ManualOptimizationModel(BoringModel):
-
-        count = 0
-        called = collections.defaultdict(int)
-
-        def __init__(self):
-            super().__init__()
-            self.automatic_optimization = False
-
-        @property
-        def should_update(self):
-            return self.count % 2 == 0
-
-        def on_train_batch_start(self, batch, batch_idx):
-            self.called["on_train_batch_start"] += 1
-            self.weight_before = self.layer.weight.clone()
-
-        def training_step(self, batch, batch_idx):
-            self.called["training_step"] += 1
-            opt = self.optimizers()
-            loss = self.step(batch)
-
-            if self.should_update:
-                self.manual_backward(loss)
-                opt.step()
-                opt.zero_grad()
-            return loss
-
-        def on_train_batch_end(self, *_):
-            self.called["on_train_batch_end"] += 1
-            after_before = self.layer.weight.clone()
-            if self.should_update:
-                assert not torch.equal(self.weight_before, after_before), self.count
-            else:
-                assert torch.equal(self.weight_before, after_before)
-                assert_emtpy_grad(self.layer.weight.grad)
-            self.count += 1
-
-        def on_train_start(self):
-            opt = self.optimizers()
-            self.opt_step_patch = patch.object(opt, "step", wraps=opt.step)
-            self.opt_step_mock = self.opt_step_patch.start()
-
-        def on_train_end(self):
-            assert self.called["training_step"] == 5
-            assert self.called["on_train_batch_start"] == 5
-            assert self.called["on_train_batch_end"] == 5
-
-            self.opt_step_patch.stop()
-            assert self.opt_step_mock.call_count == 3
-
     model = ManualOptimizationModel()
     model_copy = deepcopy(model)
@@ -204,33 +205,34 @@ def test_auto_parameters_tying_tpus(tmpdir):
     assert torch.all(torch.eq(model.layer_1.weight, model.layer_3.weight))
 
 
+class SubModule(nn.Module):
+    def __init__(self, layer):
+        super().__init__()
+        self.layer = layer
+
+    def forward(self, x):
+        return self.layer(x)
+
+
+class NestedModule(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.layer = nn.Linear(32, 10, bias=False)
+        self.net_a = SubModule(self.layer)
+        self.layer_2 = nn.Linear(10, 32, bias=False)
+        self.net_b = SubModule(self.layer)
+
+    def forward(self, x):
+        x = self.net_a(x)
+        x = self.layer_2(x)
+        x = self.net_b(x)
+        return x
+
+
 @RunIf(tpu=True)
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_auto_parameters_tying_tpus_nested_module(tmpdir):
-    class SubModule(nn.Module):
-        def __init__(self, layer):
-            super().__init__()
-            self.layer = layer
-
-        def forward(self, x):
-            return self.layer(x)
-
-    class NestedModule(BoringModel):
-        def __init__(self):
-            super().__init__()
-            self.layer = nn.Linear(32, 10, bias=False)
-            self.net_a = SubModule(self.layer)
-            self.layer_2 = nn.Linear(10, 32, bias=False)
-            self.net_b = SubModule(self.layer)
-
-        def forward(self, x):
-            x = self.net_a(x)
-            x = self.layer_2(x)
-            x = self.net_b(x)
-            return x
-
     model = NestedModule()
-
     trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=5, accelerator="tpu", devices="auto", max_epochs=1)
     trainer.fit(model)
@@ -312,8 +314,13 @@ def test_warning_if_tpus_not_used(tpu_available):
         ("2,", [2]),
     ],
 )
+@pytest.mark.parametrize("runtime", ("xrt", "pjrt"))
 @RunIf(min_python="3.9")  # mocking issue
-def test_trainer_config_device_ids(devices, expected_device_ids, tpu_available, monkeypatch):
+def test_trainer_config_device_ids(devices, expected_device_ids, runtime, tpu_available, monkeypatch):
+    from torch_xla.experimental import pjrt
+
+    monkeypatch.setattr(pjrt, "using_pjrt", lambda: runtime == "pjrt")
+
     mock = DeviceMock()
     monkeypatch.setattr(torch, "device", mock)
     if _IS_WINDOWS:
@@ -321,6 +328,7 @@ def test_trainer_config_device_ids(devices, expected_device_ids, tpu_available,
         monkeypatch.setattr(torch.multiprocessing, "get_all_start_methods", lambda: ["fork", "spawn"])
 
     trainer = Trainer(accelerator="tpu", devices=devices)
-    assert mock.mock_calls == [call("xla", i + 1) for i in expected_device_ids]
+    device_offset = int(runtime == "xrt")
+    assert mock.mock_calls == [call("xla", i + device_offset) for i in expected_device_ids]
 
     assert len(trainer.device_ids) == len(expected_device_ids)
     assert trainer.num_devices == len(expected_device_ids)
peak memory (MB)"] - for f in fields: - assert any(f in h for h in metrics) - trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, @@ -122,11 +121,19 @@ def log_metrics(self, metrics, step=None) -> None: devices="auto", log_every_n_steps=1, callbacks=[device_stats], - logger=DebugLogger(tmpdir), + logger=AssertTpuMetricsLogger(tmpdir), enable_checkpointing=False, enable_progress_bar=False, ) - trainer.fit(model) + + try: + trainer.fit(model) + except RuntimeError as e: + from torch_xla.experimental import pjrt + + if pjrt.using_pjrt() and "GetMemoryInfo not implemented" in str(e): + pytest.xfail("`xm.get_memory_info` is not implemented with PJRT") + raise e def test_device_stats_monitor_no_logger(tmpdir): diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index c16beb16369b2..c004e32ef1d01 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -204,6 +204,7 @@ def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N monkeypatch.setattr(lightning.fabric.accelerators.tpu.TPUAccelerator, "auto_device_count", lambda *_: 8) monkeypatch.setitem(sys.modules, "torch_xla", Mock()) monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock()) @pytest.fixture(scope="function") diff --git a/tests/tests_pytorch/models/test_tpu.py b/tests/tests_pytorch/models/test_tpu.py index 0b2d59db27518..d8cba60f808bb 100644 --- a/tests/tests_pytorch/models/test_tpu.py +++ b/tests/tests_pytorch/models/test_tpu.py @@ -164,17 +164,17 @@ def test_model_16bit_multiple_tpu_devices(tmpdir): tpipes.run_model_test(trainer_options, model, with_hpc=False, min_acc=0.05) +class CustomBoringModel(BoringModel): + def validation_step(self, *args, **kwargs): + out = super().validation_step(*args, **kwargs) + self.log("val_loss", out["x"]) + return out + + @RunIf(tpu=True) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_model_tpu_early_stop(tmpdir): """Test if single TPU core training works.""" - - class CustomBoringModel(BoringModel): - def validation_step(self, *args, **kwargs): - out = super().validation_step(*args, **kwargs) - self.log("val_loss", out["x"]) - return out - model = CustomBoringModel() trainer = Trainer( callbacks=[EarlyStopping(monitor="val_loss")], @@ -313,18 +313,18 @@ def test_tpu_sync_dist(): xla_launch(tpu_sync_dist_fn) +class AssertXLADebugModel(BoringModel): + def on_train_start(self): + assert os.environ.get("PT_XLA_DEBUG") == "1", "PT_XLA_DEBUG was not set in environment variables" + + def teardown(self, stage): + assert "PT_XLA_DEBUG" not in os.environ + + @RunIf(tpu=True) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_tpu_debug_mode(tmpdir): """Test if debug mode works on TPU.""" - - class DebugModel(BoringModel): - def on_train_start(self): - assert os.environ.get("PT_XLA_DEBUG") == str(1), "PT_XLA_DEBUG was not set in environment variables" - - def teardown(self, stage): - assert "PT_XLA_DEBUG" not in os.environ - trainer_options = dict( default_root_dir=tmpdir, enable_progress_bar=False, @@ -336,18 +336,23 @@ def teardown(self, stage): strategy=XLAStrategy(debug=True), ) - model = DebugModel() + model = AssertXLADebugModel() tpipes.run_model_test(trainer_options, model, with_hpc=False) +class AssertXLAWorldSizeModel(BoringModel): + def on_train_start(self): + assert os.environ.get("XRT_HOST_WORLD_SIZE") == str(1) + + @RunIf(tpu=True) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def 
diff --git a/tests/tests_pytorch/models/test_tpu.py b/tests/tests_pytorch/models/test_tpu.py
index 0b2d59db27518..d8cba60f808bb 100644
--- a/tests/tests_pytorch/models/test_tpu.py
+++ b/tests/tests_pytorch/models/test_tpu.py
@@ -164,17 +164,17 @@ def test_model_16bit_multiple_tpu_devices(tmpdir):
     tpipes.run_model_test(trainer_options, model, with_hpc=False, min_acc=0.05)
 
 
+class CustomBoringModel(BoringModel):
+    def validation_step(self, *args, **kwargs):
+        out = super().validation_step(*args, **kwargs)
+        self.log("val_loss", out["x"])
+        return out
+
+
 @RunIf(tpu=True)
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_model_tpu_early_stop(tmpdir):
     """Test if single TPU core training works."""
-
-    class CustomBoringModel(BoringModel):
-        def validation_step(self, *args, **kwargs):
-            out = super().validation_step(*args, **kwargs)
-            self.log("val_loss", out["x"])
-            return out
-
     model = CustomBoringModel()
     trainer = Trainer(
         callbacks=[EarlyStopping(monitor="val_loss")],
@@ -313,18 +313,18 @@ def test_tpu_sync_dist():
     xla_launch(tpu_sync_dist_fn)
 
 
+class AssertXLADebugModel(BoringModel):
+    def on_train_start(self):
+        assert os.environ.get("PT_XLA_DEBUG") == "1", "PT_XLA_DEBUG was not set in environment variables"
+
+    def teardown(self, stage):
+        assert "PT_XLA_DEBUG" not in os.environ
+
+
 @RunIf(tpu=True)
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_tpu_debug_mode(tmpdir):
     """Test if debug mode works on TPU."""
-
-    class DebugModel(BoringModel):
-        def on_train_start(self):
-            assert os.environ.get("PT_XLA_DEBUG") == str(1), "PT_XLA_DEBUG was not set in environment variables"
-
-        def teardown(self, stage):
-            assert "PT_XLA_DEBUG" not in os.environ
-
     trainer_options = dict(
         default_root_dir=tmpdir,
         enable_progress_bar=False,
@@ -336,18 +336,23 @@ def teardown(self, stage):
         strategy=XLAStrategy(debug=True),
     )
 
-    model = DebugModel()
+    model = AssertXLADebugModel()
     tpipes.run_model_test(trainer_options, model, with_hpc=False)
 
 
+class AssertXLAWorldSizeModel(BoringModel):
+    def on_train_start(self):
+        assert os.environ.get("XRT_HOST_WORLD_SIZE") == "1"
+
+
 @RunIf(tpu=True)
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_tpu_host_world_size(tmpdir):
     """Test Host World size env setup on TPU."""
+    from torch_xla.experimental import pjrt
 
-    class DebugModel(BoringModel):
-        def on_train_start(self):
-            assert os.environ.get("XRT_HOST_WORLD_SIZE") == str(1)
+    if pjrt.using_pjrt():
+        pytest.skip("PJRT doesn't set 'XRT_HOST_WORLD_SIZE'")
 
     trainer_options = dict(
         default_root_dir=tmpdir,
@@ -359,7 +364,7 @@ def on_train_start(self):
         limit_val_batches=0.4,
     )
 
-    model = DebugModel()
+    model = AssertXLAWorldSizeModel()
    assert "XRT_HOST_WORLD_SIZE" not in os.environ
     tpipes.run_model_test(trainer_options, model, with_hpc=False)
     assert "XRT_HOST_WORLD_SIZE" not in os.environ
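Finally, a note on how a run ends up on one runtime or the other: nothing in this diff selects PJRT itself. To the best of my understanding, torch_xla picks the runtime from the environment at the time of this PR, with `PJRT_DEVICE=TPU` enabling PJRT (and XRT used otherwise, configured via `XRT_TPU_CONFIG`). A sketch, to be run on a TPU host:

```python
import os

# hedged: this is torch_xla's runtime switch, not a Lightning API, and it
# must be set before torch_xla initializes
os.environ["PJRT_DEVICE"] = "TPU"

from torch_xla.experimental import pjrt

assert pjrt.using_pjrt()
```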