From 32ac40509b7e1bdd5a301fa0a794d0c2e4fbbc26 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 22 Oct 2021 18:39:56 +0200
Subject: [PATCH] Revert "Add XLACheckpointIO (#9972)"

This reverts commit aa1540410ff55854e050ff117c2d66f22741d182.
---
 CHANGELOG.md                                |  4 --
 docs/source/api_references.rst              |  1 -
 pytorch_lightning/plugins/__init__.py       |  2 -
 pytorch_lightning/plugins/io/__init__.py    |  1 -
 pytorch_lightning/plugins/io/xla_plugin.py  | 43 -------------------
 .../plugins/training_type/single_tpu.py     | 23 +++++++---
 .../plugins/training_type/tpu_spawn.py      | 29 +++++++------
 tests/accelerators/test_tpu.py              |  8 +---
 8 files changed, 34 insertions(+), 77 deletions(-)
 delete mode 100644 pytorch_lightning/plugins/io/xla_plugin.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 398d038595b09..7d39030cfcab0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -216,10 +216,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
   * Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023))
 
-
-- Added `XLACheckpointIO` plugin ([#9972](https://github.com/PyTorchLightning/pytorch-lightning/pull/9972))
-
-
 ### Changed
 
 - Setting `Trainer(accelerator="ddp_cpu")` now does not spawn a subprocess if `num_processes` is kept `1` along with `num_nodes > 1` ([#9603](https://github.com/PyTorchLightning/pytorch-lightning/pull/9603)).
diff --git a/docs/source/api_references.rst b/docs/source/api_references.rst
index b2d546a158f42..4b3816e567e2d 100644
--- a/docs/source/api_references.rst
+++ b/docs/source/api_references.rst
@@ -211,7 +211,6 @@ Checkpoint IO Plugins
 
     CheckpointIO
     TorchCheckpointIO
-    XLACheckpointIO
 
 Profiler API
 ------------
diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py
index 0194591bfc06c..292eb582548d9 100644
--- a/pytorch_lightning/plugins/__init__.py
+++ b/pytorch_lightning/plugins/__init__.py
@@ -4,7 +4,6 @@
 from pytorch_lightning.plugins.environments import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
-from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
 from pytorch_lightning.plugins.plugins_registry import (  # noqa: F401
     call_training_type_register_plugins,
     TrainingTypePluginsRegistry,
@@ -41,7 +40,6 @@
 __all__ = [
     "CheckpointIO",
     "TorchCheckpointIO",
-    "XLACheckpointIO",
     "ApexMixedPrecisionPlugin",
     "DataParallelPlugin",
     "DDP2Plugin",
diff --git a/pytorch_lightning/plugins/io/__init__.py b/pytorch_lightning/plugins/io/__init__.py
index 1b14eee6ec4f2..232f582c1a520 100644
--- a/pytorch_lightning/plugins/io/__init__.py
+++ b/pytorch_lightning/plugins/io/__init__.py
@@ -13,4 +13,3 @@
 # limitations under the License.
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO  # noqa: F401
 from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO  # noqa: F401
-from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO  # noqa: F401
diff --git a/pytorch_lightning/plugins/io/xla_plugin.py b/pytorch_lightning/plugins/io/xla_plugin.py
deleted file mode 100644
index c40b6a1ada037..0000000000000
--- a/pytorch_lightning/plugins/io/xla_plugin.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import Any, Dict, Optional
-
-from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
-from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE
-from pytorch_lightning.utilities.apply_func import apply_to_collection
-from pytorch_lightning.utilities.types import _PATH
-
-if _TPU_AVAILABLE:
-    import torch_xla.core.xla_model as xm
-
-if _OMEGACONF_AVAILABLE:
-    from omegaconf import DictConfig, ListConfig, OmegaConf
-
-
-class XLACheckpointIO(TorchCheckpointIO):
-    """CheckpointIO that utilizes :func:`xm.save` to save checkpoints for TPU training strategies."""
-
-    def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
-        """Save model/training states as a checkpoint file through state-dump and file-write.
-
-        Args:
-            checkpoint: dict containing model and trainer state
-            path: write-target path
-            storage_options: Optional parameters when saving the model/training states.
-        """
-        # Todo: TypeError: 'mappingproxy' object does not support item assignment
-        # Ref: https://github.com/pytorch/xla/issues/2773
-        if _OMEGACONF_AVAILABLE:
-            checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
-        xm.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, path)
diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py
index 9fed2000391dd..81da403e098ea 100644
--- a/pytorch_lightning/plugins/training_type/single_tpu.py
+++ b/pytorch_lightning/plugins/training_type/single_tpu.py
@@ -12,12 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from typing import Any, Dict, Optional
+from typing import Any, Dict
 
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
-from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
-from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
+from pytorch_lightning.utilities import (
+    _OMEGACONF_AVAILABLE,
+    _TPU_AVAILABLE,
+    find_shared_parameters,
+    set_shared_parameters,
+)
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import _PATH
@@ -25,6 +30,9 @@
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
 
+if _OMEGACONF_AVAILABLE:
+    from omegaconf import DictConfig, ListConfig, OmegaConf
+
 
 class SingleTPUPlugin(SingleDevicePlugin):
     """Plugin for training on a single TPU device."""
@@ -32,13 +40,11 @@ class SingleTPUPlugin(SingleDevicePlugin):
     def __init__(
         self,
         device: int,
-        checkpoint_io: Optional[CheckpointIO] = None,
         debug: bool = False,
     ):
 
         device = xm.xla_device(device)
-        checkpoint_io = checkpoint_io or XLACheckpointIO()
-        super().__init__(device=device, checkpoint_io=checkpoint_io)
+        super().__init__(device=device)
 
         self.debug = debug
         self.tpu_local_core_rank = 0
@@ -79,7 +85,10 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
             checkpoint: dict containing model and trainer state
             filepath: write-target file's path
         """
-        return self.checkpoint_io.save_checkpoint(checkpoint, filepath)
+        # Related Issue: https://github.com/pytorch/xla/issues/2773
+        if _OMEGACONF_AVAILABLE:
+            checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
+        self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
 
     def teardown(self) -> None:
         # TPU teardown
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index f8968a69ceed1..6d18612b94f50 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -25,11 +25,17 @@
 import pytorch_lightning as pl
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
-from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.states import TrainerFn
-from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters
+from pytorch_lightning.utilities import (
+    _OMEGACONF_AVAILABLE,
+    _TPU_AVAILABLE,
+    find_shared_parameters,
+    rank_zero_warn,
+    set_shared_parameters,
+)
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -46,19 +52,15 @@
 else:
     xm, xmp, MpDeviceLoader, rendezvous = [None] * 4
 
+if _OMEGACONF_AVAILABLE:
+    from omegaconf import DictConfig, ListConfig, OmegaConf
+
 
 class TPUSpawnPlugin(DDPSpawnPlugin):
     """Plugin for training multiple TPU devices using the :func:`torch.multiprocessing.spawn` method."""
 
-    def __init__(
-        self,
-        parallel_devices: Optional[List[int]] = None,
-        checkpoint_io: Optional[CheckpointIO] = None,
-        debug: bool = False,
-        **_: Any
-    ) -> None:
-        checkpoint_io = checkpoint_io or XLACheckpointIO()
-        super().__init__(parallel_devices=parallel_devices, checkpoint_io=checkpoint_io)
+    def __init__(self, parallel_devices: Optional[List[int]] = None, debug: bool = False, **_: Any) -> None:
+        super().__init__(parallel_devices=parallel_devices)
         self.debug = debug
         self.tpu_local_core_rank = 0
         self.tpu_global_core_rank = 0
@@ -315,7 +317,10 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
             checkpoint: dict containing model and trainer state
             filepath: write-target file's path
         """
-        return self.checkpoint_io.save_checkpoint(checkpoint, filepath)
+        # Todo: TypeError: 'mappingproxy' object does not support item assignment
+        if _OMEGACONF_AVAILABLE:
+            checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
+        self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
 
     def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
         """
diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py
index b93b46e8e1407..62789d1a541dd 100644
--- a/tests/accelerators/test_tpu.py
+++ b/tests/accelerators/test_tpu.py
@@ -22,7 +22,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
 from pytorch_lightning.accelerators.tpu import TPUAccelerator
-from pytorch_lightning.plugins import TPUPrecisionPlugin, TPUSpawnPlugin, XLACheckpointIO
+from pytorch_lightning.plugins import TPUPrecisionPlugin, TPUSpawnPlugin
 from pytorch_lightning.utilities import find_shared_parameters
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
@@ -294,9 +294,3 @@ def test_tpu_invalid_raises():
     accelerator = TPUAccelerator(TPUPrecisionPlugin(), object())
     with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugi"):
         accelerator.setup(object())
-
-
-@RunIf(tpu=True)
-def test_xla_checkpoint_plugin_being_default():
-    trainer = Trainer(tpu_cores=8)
-    assert isinstance(trainer.training_type_plugin.checkpoint_io, XLACheckpointIO)
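
Reviewer note (not part of the patch): with the revert applied, the TPU plugins no longer delegate to `self.checkpoint_io`; `save_checkpoint` goes back to converting OmegaConf containers and writing through `xm.save` directly. A minimal sketch of that restored logic is shown below. The helper name `save_tpu_checkpoint` is hypothetical and only illustrative; it assumes a TPU/XLA environment with `torch_xla` installed and, optionally, `omegaconf`.

from typing import Any, Dict

import torch_xla.core.xla_model as xm  # requires a TPU/XLA environment

from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.apply_func import apply_to_collection

if _OMEGACONF_AVAILABLE:
    from omegaconf import DictConfig, ListConfig, OmegaConf


def save_tpu_checkpoint(checkpoint: Dict[str, Any], filepath: str) -> None:
    """Hypothetical helper mirroring the restored TPU plugin `save_checkpoint` logic."""
    # xm.save cannot serialize OmegaConf containers (https://github.com/pytorch/xla/issues/2773),
    # so convert DictConfig/ListConfig values back to plain dicts/lists first.
    if _OMEGACONF_AVAILABLE:
        checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
    # Drop the callbacks state and let xm.save move XLA tensors to CPU before writing to disk.
    xm.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)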