Replace DataLoader sampler once for IPUs (#8858)
carmocca authored Aug 16, 2021
1 parent 1d2f7e2 commit 93ab24d
Showing 11 changed files with 97 additions and 265 deletions.
17 changes: 1 addition & 16 deletions .azure-pipelines/ipu-tests.yml
@@ -72,26 +72,11 @@ jobs:
python -c "import poptorch; print(poptorch.__version__)"
displayName: "Check poptorch installation"
- bash: |
wget https://pl-public-data.s3.amazonaws.com/legacy/checkpoints.zip -P legacy/
unzip -o legacy/checkpoints.zip -d legacy/
ls -l legacy/checkpoints/
displayName: 'Get legacy checkpoints'
- bash: |
source ${{ variables.poplar_sdk }}/poplar-ubuntu*/enable.sh
source ${{ variables.poplar_sdk }}/popart-ubuntu*/enable.sh
export POPTORCH_WAIT_FOR_IPU=1
python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=50
python -m coverage run --source pytorch_lightning -m pytest tests/accelerators/test_ipu.py -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=50
env:
MKL_THREADING_LAYER: "GNU"
displayName: 'Testing: standard'
- bash: |
source ${{ variables.poplar_sdk }}/poplar-ubuntu*/enable.sh
source ${{ variables.poplar_sdk }}/popart-ubuntu*/enable.sh
export POPTORCH_WAIT_FOR_IPU=1
bash tests/special_tests.sh
env:
MKL_THREADING_LAYER: "GNU"
displayName: 'Testing: special'
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -81,6 +81,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
(https://github.com/PyTorchLightning/pytorch-lightning/pull/8608))


- `Trainer.request_dataloader` now takes a `RunningStage` enum instance ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))

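A minimal usage sketch of the signature change noted above; the calls below are illustrative, assume an existing `trainer` instance, and are not an excerpt from this commit:

```python
from pytorch_lightning.trainer.states import RunningStage

# per the changelog entry above, request_dataloader now takes a RunningStage member
train_loader = trainer.request_dataloader(RunningStage.TRAINING)
val_loaders = trainer.request_dataloader(RunningStage.VALIDATING)
```
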
### Deprecated

- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
@@ -132,6 +134,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `LightningModule.write_predictions` and `LightningModule.write_predictions_dict` ([#8850](https://github.com/PyTorchLightning/pytorch-lightning/pull/8850))


- Removed the reset dataloader hooks from Training Plugins and Accelerators ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))



### Fixed

@@ -176,6 +181,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed infinite loop with CycleIterator and multiple loaders ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889))


- Fixed a bug where data-loading functions were not getting the correct running stage passed ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))


## [1.4.0] - 2021-07-27

### Added
16 changes: 0 additions & 16 deletions pytorch_lightning/accelerators/accelerator.py
@@ -410,22 +410,6 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I
"""
return self.training_type_plugin.process_dataloader(dataloader)

def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Called before resetting the train dataloader."""
return self.training_type_plugin.on_reset_train_dataloader(dataloader)

def on_reset_val_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Called before resetting the val dataloader."""
return self.training_type_plugin.on_reset_val_dataloader(dataloader)

def on_reset_test_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Called before resetting the test dataloader."""
return self.training_type_plugin.on_reset_test_dataloader(dataloader)

def on_reset_predict_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Called before resetting the predict dataloader."""
return self.training_type_plugin.on_reset_predict_dataloader(dataloader)

@property
def results(self) -> Any:
"""
69 changes: 16 additions & 53 deletions pytorch_lightning/plugins/training_type/ipu.py
@@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
from typing import Any, Iterable, List, Optional, Union
from typing import Any, List, Optional, Union

import torch
from torch.utils.data import DataLoader
@@ -26,7 +25,6 @@
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities import _POPTORCH_AVAILABLE
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -112,6 +110,12 @@ def __init__(
options["autoReport.directory"] = self.autoreport_dir
os.environ["POPLAR_ENGINE_OPTIONS"] = json.dumps(options)

def setup(self) -> None:
# patch the dataloader creation function with the custom `poptorch.DataLoader`.
# this violates the intended control flow for the plugins, but since this is experimental, we have chosen
# to use the simpler solution before adding abstractions to override the `DataLoader` class
self.lightning_module.trainer.replace_sampler = self._convert_to_poptorch_loader

def pre_dispatch(self) -> None:
precision = self.lightning_module.trainer.precision
model = LightningIPUModule(self.lightning_module, precision)
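The comment in the new `setup` above explains the trick: for the duration of the run, the trainer's `replace_sampler` method is shadowed by an attribute set on the trainer instance, so every dataloader the trainer rebuilds goes through the plugin. A tiny self-contained sketch of that mechanism with toy stand-in classes (not the real `Trainer` or `IPUPlugin`):

```python
class ToyTrainer:
    def replace_sampler(self, dataloader, sampler, mode=None):
        return ("stock", dataloader, sampler)


class ToyPlugin:
    def __init__(self, trainer):
        self.trainer = trainer

    def _convert(self, dataloader, sampler, mode=None):
        return ("plugin", dataloader, sampler)

    def setup(self):
        # the instance attribute shadows the class method until teardown
        self.trainer.replace_sampler = self._convert

    def teardown(self):
        # drop the shadow so lookups fall back to the class method again
        del self.trainer.replace_sampler


trainer = ToyTrainer()
plugin = ToyPlugin(trainer)
plugin.setup()
assert trainer.replace_sampler("dl", "sampler")[0] == "plugin"
plugin.teardown()
assert trainer.replace_sampler("dl", "sampler")[0] == "stock"
```

The real plugin restores the original by re-assigning `TrainerDataLoadingMixin.replace_sampler` in its `teardown` (see further down in this file); the `del` above is just the simplest way to express the same intent in a toy.
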
@@ -169,59 +173,16 @@ def inference_opts(self) -> "poptorch.Options":
def lightning_module(self) -> Optional["pl.LightningModule"]:
return self.model.module if isinstance(self.model, LightningIPUModule) else self.model

def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
return self._process_dataloader(dataloader, is_training=True)

def on_reset_val_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
return self._process_dataloader(dataloader, is_training=False)

def on_reset_test_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
return self._process_dataloader(dataloader, is_training=False)

def on_reset_predict_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
return self._process_dataloader(dataloader, is_training=False)

def _process_dataloader(
self, dataloader: Union[Iterable, DataLoader], is_training: bool
) -> Union[Iterable, DataLoader]:
if isinstance(dataloader, CombinedLoader):
dataloader.loaders = apply_to_collection(
dataloader.loaders, DataLoader, self._process_dataloader, is_training
)
return dataloader
if isinstance(dataloader, list):
dataloader = apply_to_collection(dataloader, DataLoader, self._process_dataloader, is_training)
return dataloader
if not isinstance(dataloader, poptorch.DataLoader):
opts = self.training_opts if is_training else self.inference_opts
dataloader = self._convert_to_poptorch_loader(dataloader=dataloader, opts=opts)
return dataloader

def _convert_to_poptorch_loader(
self, dataloader: Union[Iterable, DataLoader], opts: "poptorch.Options"
) -> Union[Iterable, DataLoader]:
skip_keys = ("sampler", "batch_sampler", "dataset_kind")

attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}

params = set(inspect.signature(dataloader.__init__).parameters)
contains_dataset = True

if type(dataloader) is not DataLoader:
contains_dataset = "dataset" in params
params.update(inspect.signature(DataLoader.__init__).parameters)

dl_args = {name: attrs[name] for name in params if name in attrs and name not in skip_keys}

multiprocessing_context = dataloader.multiprocessing_context
dl_args["multiprocessing_context"] = multiprocessing_context
if not contains_dataset:
dl_args.pop("dataset")
self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None
) -> "poptorch.DataLoader":
# use full path to avoid circular imports
dl_kwargs = pl.trainer.trainer.TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
# Override to drop the last uneven batch, as IPUs do not support uneven inputs.
dl_args["drop_last"] = True
dl_kwargs["drop_last"] = True

dataloader = poptorch.DataLoader(**dl_args, options=opts)
dataloader.multiprocessing_context = multiprocessing_context
opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
dataloader = poptorch.DataLoader(**dl_kwargs, options=opts)
return dataloader

@property
@@ -291,6 +252,8 @@ def predict_step(self, *args, **kwargs):
return self.poptorch_models[RunningStage.PREDICTING](*args, **kwargs)

def teardown(self) -> None:
# undo dataloader patching
self.lightning_module.trainer.replace_sampler = pl.trainer.trainer.TrainerDataLoadingMixin.replace_sampler
for model in self.poptorch_models.values():
model.destroy()

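To make the effect of the new `_convert_to_poptorch_loader` concrete, below is a rough before/after sketch of what the patched `replace_sampler` effectively produces for a user's `DataLoader`. It assumes a machine with the Poplar SDK and `poptorch` installed; the option object, dataset and batch settings are illustrative, and the real plugin derives the keyword arguments via `TrainerDataLoadingMixin._get_dataloader_init_kwargs` rather than spelling them out:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import poptorch  # requires the Poplar SDK / poptorch wheel

dataset = TensorDataset(torch.randn(64, 3))

# what the user hands to the Trainer
plain_loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

# roughly what the IPU plugin rebuilds it into: the same init kwargs, but
# drop_last is forced on (IPUs cannot handle a ragged final batch) and the
# poptorch options come from training_opts vs. inference_opts depending on the stage
opts = poptorch.Options()
ipu_loader = poptorch.DataLoader(
    options=opts,
    dataset=dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)
```
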
16 changes: 0 additions & 16 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -212,22 +212,6 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I
"""
return dataloader

def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Called before resetting the train dataloader."""
return dataloader

def on_reset_val_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Called before resetting the val dataloader."""
return dataloader

def on_reset_test_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Called before resetting the test dataloader."""
return dataloader

def on_reset_predict_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Called before resetting the predict dataloader."""
return dataloader

def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
return trainer.init_optimizers(model)

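For a third-party plugin that previously overrode the hooks removed above, the pattern this commit demonstrates in the IPU plugin is the available replacement: patch `trainer.replace_sampler` in `setup()` and restore it in `teardown()`. A hypothetical migration sketch follows; `MyLoaderWrappingPlugin` and `_wrap_dataloader` are invented names, and the abstract methods a real `TrainingTypePlugin` subclass must implement are omitted:

```python
from typing import Optional

from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.trainer.states import RunningStage


class MyLoaderWrappingPlugin(TrainingTypePlugin):
    """Hypothetical plugin that wants to rebuild every DataLoader it is handed."""

    def setup(self) -> None:
        # same trick as the IPU plugin: swap the dataloader factory once
        self.lightning_module.trainer.replace_sampler = self._wrap_dataloader

    def _wrap_dataloader(
        self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None
    ) -> DataLoader:
        # recover the original constructor kwargs, substituting the trainer's sampler
        dl_kwargs = pl.trainer.trainer.TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
        dl_kwargs["drop_last"] = True  # device-specific tweaks would go here
        return DataLoader(**dl_kwargs)

    def teardown(self) -> None:
        # hand the stock implementation back to the trainer when the run ends
        self.lightning_module.trainer.replace_sampler = pl.trainer.trainer.TrainerDataLoadingMixin.replace_sampler
```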