
Add XLACheckpointIO #9972

Merged: 7 commits, Oct 20, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -209,6 +209,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
* Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))


- Added `XLACheckpointIO` plugin ([#9972](https://github.com/PyTorchLightning/pytorch-lightning/pull/9972))


### Changed

- Setting `Trainer(accelerator="ddp_cpu")` now does not spawn a subprocess if `num_processes` is kept `1` along with `num_nodes > 1` ([#9603](https://github.com/PyTorchLightning/pytorch-lightning/pull/9603)).
1 change: 1 addition & 0 deletions docs/source/api_references.rst
@@ -210,6 +210,7 @@ Checkpoint IO Plugins

CheckpointIO
TorchCheckpointIO
XLACheckpointIO

Profiler API
------------
2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/__init__.py
@@ -4,6 +4,7 @@
from pytorch_lightning.plugins.environments import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
from pytorch_lightning.plugins.plugins_registry import ( # noqa: F401
call_training_type_register_plugins,
TrainingTypePluginsRegistry,
@@ -40,6 +41,7 @@
__all__ = [
"CheckpointIO",
"TorchCheckpointIO",
"XLACheckpointIO",
"ApexMixedPrecisionPlugin",
"DataParallelPlugin",
"DDP2Plugin",
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/io/__init__.py
@@ -13,3 +13,4 @@
# limitations under the License.
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO # noqa: F401
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO # noqa: F401
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO # noqa: F401
46 changes: 46 additions & 0 deletions pytorch_lightning/plugins/io/xla_plugin.py
@@ -0,0 +1,46 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, Optional

from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.types import _PATH

if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm

if _OMEGACONF_AVAILABLE:
from omegaconf import DictConfig, ListConfig, OmegaConf

log = logging.getLogger(__name__)


class XLACheckpointIO(TorchCheckpointIO):
"""CheckpointIO that utilizes :func:`xm.save` to save checkpoints for TPU training strategies."""

def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.

Args:
checkpoint: dict containing model and trainer state
path: write-target path
storage_options: Optional parameters when saving the model/training states.
"""
# Todo: TypeError: 'mappingproxy' object does not support item assignment
# Ref: https://github.com/pytorch/xla/issues/2773
if _OMEGACONF_AVAILABLE:
checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
xm.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, path)
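For orientation, here is a minimal sketch of driving the new plugin by hand, outside the Trainer. It assumes a host with `torch_xla` available, since `save_checkpoint` calls `xm.save` unconditionally; the toy model and file name are made up for illustration.

```python
import torch

from pytorch_lightning.plugins import XLACheckpointIO

model = torch.nn.Linear(4, 2)
checkpoint = {"state_dict": model.state_dict(), "epoch": 3}

io = XLACheckpointIO()
# `xm.save` moves tensors to CPU and writes from the master ordinal only,
# which is why the TPU strategies prefer it over a plain `torch.save`.
io.save_checkpoint(checkpoint, "example.ckpt")

# Loading is inherited unchanged from `TorchCheckpointIO`.
restored = io.load_checkpoint("example.ckpt")
```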
23 changes: 7 additions & 16 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -12,39 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict
from typing import Any, Dict, Optional

from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.utilities import (
_OMEGACONF_AVAILABLE,
_TPU_AVAILABLE,
find_shared_parameters,
set_shared_parameters,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _PATH

if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm

if _OMEGACONF_AVAILABLE:
from omegaconf import DictConfig, ListConfig, OmegaConf


class SingleTPUPlugin(SingleDevicePlugin):
"""Plugin for training on a single TPU device."""

def __init__(
self,
device: int,
checkpoint_io: Optional[CheckpointIO] = None,
debug: bool = False,
):

device = xm.xla_device(device)
super().__init__(device=device)
checkpoint_io = checkpoint_io or XLACheckpointIO()
super().__init__(device=device, checkpoint_io=checkpoint_io)

self.debug = debug
self.tpu_local_core_rank = 0
@@ -85,10 +79,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
checkpoint: dict containing model and trainer state
filepath: write-target file's path
"""
# Related Issue: https://github.com/pytorch/xla/issues/2773
if _OMEGACONF_AVAILABLE:
checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
return self.checkpoint_io.save_checkpoint(checkpoint, filepath)

def teardown(self) -> None:
# TPU teardown
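Since `SingleTPUPlugin.save_checkpoint` now simply forwards to `self.checkpoint_io`, a custom IO plugin is a constructor argument away. A sketch under the assumption that a TPU device is available (the plugin resolves `xm.xla_device` at construction time); the `LoggingXLACheckpointIO` subclass is hypothetical.

```python
from typing import Any, Dict, Optional

from pytorch_lightning.plugins import SingleTPUPlugin, XLACheckpointIO


class LoggingXLACheckpointIO(XLACheckpointIO):
    """Hypothetical subclass that reports every write before deferring to ``xm.save``."""

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: str, storage_options: Optional[Any] = None) -> None:
        print(f"saving checkpoint to {path}")
        super().save_checkpoint(checkpoint, path, storage_options)


# Omitting `checkpoint_io` falls back to a plain `XLACheckpointIO`.
plugin = SingleTPUPlugin(device=0, checkpoint_io=LoggingXLACheckpointIO())
```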
29 changes: 12 additions & 17 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -25,17 +25,11 @@
import pytorch_lightning as pl
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import (
_OMEGACONF_AVAILABLE,
_TPU_AVAILABLE,
find_shared_parameters,
rank_zero_warn,
set_shared_parameters,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -52,15 +46,19 @@
else:
xm, xmp, MpDeviceLoader, rendezvous = [None] * 4

if _OMEGACONF_AVAILABLE:
from omegaconf import DictConfig, ListConfig, OmegaConf


class TPUSpawnPlugin(DDPSpawnPlugin):
"""Plugin for training multiple TPU devices using the :func:`torch.multiprocessing.spawn` method."""

def __init__(self, parallel_devices: Optional[List[int]] = None, debug: bool = False, **_: Any) -> None:
super().__init__(parallel_devices=parallel_devices)
def __init__(
self,
parallel_devices: Optional[List[int]] = None,
checkpoint_io: Optional[CheckpointIO] = None,
debug: bool = False,
**_: Any
) -> None:
checkpoint_io = checkpoint_io or XLACheckpointIO()
super().__init__(parallel_devices=parallel_devices, checkpoint_io=checkpoint_io)
self.debug = debug
self.tpu_local_core_rank = 0
self.tpu_global_core_rank = 0
@@ -321,10 +319,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
checkpoint: dict containing model and trainer state
filepath: write-target file's path
"""
# Todo: TypeError: 'mappingproxy' object does not support item assignment
if _OMEGACONF_AVAILABLE:
checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
return self.checkpoint_io.save_checkpoint(checkpoint, filepath)

def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
"""
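The spawn plugin gains the same hook. A sketch of wiring it through the Trainer, assuming this release accepts a pre-built `TPUSpawnPlugin` via the `plugins` argument alongside `tpu_cores`; passing the IO plugin explicitly is equivalent to the new default.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import TPUSpawnPlugin, XLACheckpointIO

# Explicit and default behaviour are the same here: leaving `checkpoint_io`
# out of the constructor also yields an `XLACheckpointIO` instance.
trainer = Trainer(
    tpu_cores=8,
    plugins=[TPUSpawnPlugin(checkpoint_io=XLACheckpointIO())],
)
```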
8 changes: 7 additions & 1 deletion tests/accelerators/test_tpu.py
@@ -22,7 +22,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.accelerators.tpu import TPUAccelerator
from pytorch_lightning.plugins import TPUPrecisionPlugin, TPUSpawnPlugin
from pytorch_lightning.plugins import TPUPrecisionPlugin, TPUSpawnPlugin, XLACheckpointIO
from pytorch_lightning.utilities import find_shared_parameters
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
@@ -294,3 +294,9 @@ def test_tpu_invalid_raises():
accelerator = TPUAccelerator(TPUPrecisionPlugin(), object())
with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugi"):
accelerator.setup(object())


@RunIf(tpu=True)
def test_xla_checkpoint_plugin_being_default():
trainer = Trainer(tpu_cores=8)
assert isinstance(trainer.training_type_plugin.checkpoint_io, XLACheckpointIO)
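A natural companion check, not part of this PR, would assert that a user-supplied IO plugin is kept rather than replaced by the default; a hedged sketch reusing the imports already present in this test module:

```python
@RunIf(tpu=True)
def test_tpu_spawn_keeps_custom_checkpoint_io():
    io = XLACheckpointIO()
    plugin = TPUSpawnPlugin(checkpoint_io=io)
    assert plugin.checkpoint_io is io
```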