Revert "Add XLACheckpointIO (#9972)"
This reverts commit aa15404.
awaelchli committed Oct 22, 2021
1 parent c5c5afc commit 32ac405
Showing 8 changed files with 34 additions and 77 deletions.
4 changes: 0 additions & 4 deletions CHANGELOG.md
@@ -216,10 +216,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023))



- Added `XLACheckpointIO` plugin ([#9972](https://github.com/PyTorchLightning/pytorch-lightning/pull/9972))


### Changed

- Setting `Trainer(accelerator="ddp_cpu")` now does not spawn a subprocess if `num_processes` is kept `1` along with `num_nodes > 1` ([#9603](https://github.com/PyTorchLightning/pytorch-lightning/pull/9603)).
1 change: 0 additions & 1 deletion docs/source/api_references.rst
@@ -211,7 +211,6 @@ Checkpoint IO Plugins

CheckpointIO
TorchCheckpointIO
XLACheckpointIO

Profiler API
------------
2 changes: 0 additions & 2 deletions pytorch_lightning/plugins/__init__.py
@@ -4,7 +4,6 @@
from pytorch_lightning.plugins.environments import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
from pytorch_lightning.plugins.plugins_registry import ( # noqa: F401
call_training_type_register_plugins,
TrainingTypePluginsRegistry,
@@ -41,7 +40,6 @@
__all__ = [
"CheckpointIO",
"TorchCheckpointIO",
"XLACheckpointIO",
"ApexMixedPrecisionPlugin",
"DataParallelPlugin",
"DDP2Plugin",
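
The two lines dropped above remove the re-export, so `XLACheckpointIO` is no longer importable from `pytorch_lightning.plugins`. A minimal, illustrative sketch of how downstream code could probe for the symbol across versions (the flag name is made up, not part of the library):

```python
# Illustrative only: check whether the installed pytorch-lightning build still
# re-exports XLACheckpointIO (it does not after this revert).
try:
    from pytorch_lightning.plugins import XLACheckpointIO  # noqa: F401  # removed by this commit
    HAS_XLA_CHECKPOINT_IO = True
except ImportError:
    HAS_XLA_CHECKPOINT_IO = False

print(f"XLACheckpointIO available: {HAS_XLA_CHECKPOINT_IO}")
```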
1 change: 0 additions & 1 deletion pytorch_lightning/plugins/io/__init__.py
@@ -13,4 +13,3 @@
# limitations under the License.
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO # noqa: F401
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO # noqa: F401
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO # noqa: F401
43 changes: 0 additions & 43 deletions pytorch_lightning/plugins/io/xla_plugin.py

This file was deleted.
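
The deleted module's contents are not shown in this view. For orientation only, here is a hypothetical sketch of a plugin in the same spirit: a `TorchCheckpointIO` subclass whose `save_checkpoint` goes through `xm.save`, which gathers XLA tensors to CPU before serialization. The class name, signature, and body are assumptions, not the reverted file:

```python
# Hypothetical sketch, NOT the deleted file's contents: a CheckpointIO-style
# plugin that routes saving through torch_xla's xm.save.
from typing import Any, Dict

from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.types import _PATH

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm


class SketchXLACheckpointIO(TorchCheckpointIO):
    """Save checkpoints with ``xm.save``, which moves XLA tensors to CPU before writing."""

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH) -> None:
        # xm.save handles the device-to-CPU transfer that plain torch.save would miss.
        xm.save(checkpoint, path)
```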

23 changes: 16 additions & 7 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -12,33 +12,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict, Optional
from typing import Any, Dict

from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
from pytorch_lightning.utilities import (
_OMEGACONF_AVAILABLE,
_TPU_AVAILABLE,
find_shared_parameters,
set_shared_parameters,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _PATH

if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm

if _OMEGACONF_AVAILABLE:
from omegaconf import DictConfig, ListConfig, OmegaConf


class SingleTPUPlugin(SingleDevicePlugin):
"""Plugin for training on a single TPU device."""

def __init__(
self,
device: int,
checkpoint_io: Optional[CheckpointIO] = None,
debug: bool = False,
):

device = xm.xla_device(device)
checkpoint_io = checkpoint_io or XLACheckpointIO()
super().__init__(device=device, checkpoint_io=checkpoint_io)
super().__init__(device=device)

self.debug = debug
self.tpu_local_core_rank = 0
@@ -79,7 +85,10 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
checkpoint: dict containing model and trainer state
filepath: write-target file's path
"""
return self.checkpoint_io.save_checkpoint(checkpoint, filepath)
# Related Issue: https://github.com/pytorch/xla/issues/2773
if _OMEGACONF_AVAILABLE:
checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)

def teardown(self) -> None:
# TPU teardown
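
The restored `save_checkpoint` converts OmegaConf containers to plain Python objects (working around pytorch/xla#2773) and drops the `callbacks` entry before saving. A small sketch of that conversion step, with `torch.save` standing in for `xm.save` so it runs without a TPU; the checkpoint contents and output path are made up for illustration:

```python
# Sketch: convert OmegaConf containers to plain dicts/lists before serializing,
# mirroring the apply_to_collection call restored above.
import torch
from omegaconf import DictConfig, ListConfig, OmegaConf

from pytorch_lightning.utilities.apply_func import apply_to_collection

checkpoint = {
    "state_dict": {"weight": torch.zeros(2, 2)},
    "hyper_parameters": OmegaConf.create({"lr": 1e-3, "layers": [64, 32]}),
    "callbacks": {},
}

# DictConfig/ListConfig values become plain dict/list, avoiding the save
# failure tracked in pytorch/xla#2773.
checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)

# The restored code also strips the "callbacks" entry before writing.
torch.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, "single_tpu_sketch.ckpt")
```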
29 changes: 17 additions & 12 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -25,11 +25,17 @@
import pytorch_lightning as pl
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters
from pytorch_lightning.utilities import (
_OMEGACONF_AVAILABLE,
_TPU_AVAILABLE,
find_shared_parameters,
rank_zero_warn,
set_shared_parameters,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -46,19 +52,15 @@
else:
xm, xmp, MpDeviceLoader, rendezvous = [None] * 4

if _OMEGACONF_AVAILABLE:
from omegaconf import DictConfig, ListConfig, OmegaConf


class TPUSpawnPlugin(DDPSpawnPlugin):
"""Plugin for training multiple TPU devices using the :func:`torch.multiprocessing.spawn` method."""

def __init__(
self,
parallel_devices: Optional[List[int]] = None,
checkpoint_io: Optional[CheckpointIO] = None,
debug: bool = False,
**_: Any
) -> None:
checkpoint_io = checkpoint_io or XLACheckpointIO()
super().__init__(parallel_devices=parallel_devices, checkpoint_io=checkpoint_io)
def __init__(self, parallel_devices: Optional[List[int]] = None, debug: bool = False, **_: Any) -> None:
super().__init__(parallel_devices=parallel_devices)
self.debug = debug
self.tpu_local_core_rank = 0
self.tpu_global_core_rank = 0
@@ -315,7 +317,10 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
checkpoint: dict containing model and trainer state
filepath: write-target file's path
"""
return self.checkpoint_io.save_checkpoint(checkpoint, filepath)
# Todo: TypeError: 'mappingproxy' object does not support item assignment
if _OMEGACONF_AVAILABLE:
checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)

def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
"""
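
As in `single_tpu.py`, the revert removes the injectable `checkpoint_io` constructor argument. A sketch of the constructor call before and after this commit; the device indices are illustrative, and actually spawning on them requires a TPU environment:

```python
# Sketch of the signature change restored by this revert; device indices are illustrative.
from pytorch_lightning.plugins import TPUSpawnPlugin

# Prior to this revert (commit aa15404), an optional CheckpointIO could be injected:
# plugin = TPUSpawnPlugin(parallel_devices=[0, 1, 2, 3], checkpoint_io=XLACheckpointIO())

# After the revert, checkpoint writing lives in the plugin's own save_checkpoint:
plugin = TPUSpawnPlugin(parallel_devices=[0, 1, 2, 3], debug=False)
```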
8 changes: 1 addition & 7 deletions tests/accelerators/test_tpu.py
@@ -22,7 +22,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.accelerators.tpu import TPUAccelerator
from pytorch_lightning.plugins import TPUPrecisionPlugin, TPUSpawnPlugin, XLACheckpointIO
from pytorch_lightning.plugins import TPUPrecisionPlugin, TPUSpawnPlugin
from pytorch_lightning.utilities import find_shared_parameters
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
@@ -294,9 +294,3 @@ def test_tpu_invalid_raises():
accelerator = TPUAccelerator(TPUPrecisionPlugin(), object())
with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugi"):
accelerator.setup(object())


@RunIf(tpu=True)
def test_xla_checkpoint_plugin_being_default():
trainer = Trainer(tpu_cores=8)
assert isinstance(trainer.training_type_plugin.checkpoint_io, XLACheckpointIO)
