Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix DDP support #1182

Merged
merged 18 commits into from
Feb 22, 2022
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed a bug where DDP would not work with Flash tasks ([#1182](https://github.com/PyTorchLightning/lightning-flash/pull/1182))

## [0.7.0] - 2022-02-15

### Added
Expand Down
37 changes: 0 additions & 37 deletions flash/core/data/io/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,43 +300,6 @@ def predict_load_sample(self, sample: Dict[str, Any]) -> Any:
"""
return self.load_sample(sample)

def __getstate__(self):
    """Build the state used for pickling, without the loaded data.

    Temporary override of the default pickle behaviour.

    TODO: New DataPipeline should avoid this being pickled.
    """
    picklable = dict(self.__dict__)
    # ``data`` is expected to be present (a ``KeyError`` is raised otherwise).
    del picklable["data"]
    # ``data_iter`` is only dropped when it exists.
    picklable.pop("data_iter", None)
    return picklable

def __setstate__(self, newstate):
    """Restore the instance from ``newstate`` with ``data`` reset to ``None``.

    Temporary override of the default pickle behaviour.

    TODO: New DataPipeline should avoid this being pickled.
    """
    # The incoming mapping is updated in place before being applied to the instance.
    newstate.update(data=None)
    vars(self).update(newstate)

def __copy__(self):
    """Shallow-copy the instance, data list included.

    The default copy implementation seems to go through ``__getstate__`` and ``__setstate__``, so a custom
    implementation is provided here to make sure the attributes are carried over as they are.
    """
    klass = self.__class__
    duplicate = klass.__new__(klass)
    # Share every attribute with the source instance (no nested copying).
    vars(duplicate).update(vars(self))
    return duplicate

def __deepcopy__(self, memo):
    """Deep-copy the instance, data list included.

    The default deepcopy implementation seems to go through ``__getstate__`` and ``__setstate__``, so a custom
    implementation is provided here to make sure every attribute is copied.
    """
    klass = self.__class__
    duplicate = klass.__new__(klass)
    # Register the new object before recursing so that self-references resolve to it.
    memo[id(self)] = duplicate
    for attr_name, attr_value in vars(self).items():
        setattr(duplicate, attr_name, deepcopy(attr_value, memo))
    return duplicate
Comment on lines -303 to -338
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@krshrimali This is the main fix. We used to have a bug where the data was accidentally included in the checkpoint. We patched that by adding these overrides. But then DDP spawn needs to pickle the data to send it to each process, so this causes problems. We refactored away the bit that got this included in the checkpoint, so they can now be safely removed 😃

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, thank you so much for the explanation, @ethanwharris!


def __bool__(self):
"""If ``self.data`` is ``None`` then the ``InputBase`` is considered falsey.

Expand Down
53 changes: 53 additions & 0 deletions flash/core/data/io/transform_predictions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools

import pytorch_lightning as pl
from pytorch_lightning import Callback

from flash.core.data.io.output import Output
from flash.core.data.io.output_transform import OutputTransform


class TransformPredictions(Callback):
"""``TransformPredictions`` is a :class:`~pytorch_lightning.callbacks.base.Callback` which can be used to apply an
:class:`~flash.core.data.io.output_transform.OutputTransform` and an :class:`~flash.core.data.io.output.Output` to
model predictions.

Args:
output_transform: The :class:`~flash.core.data.io.output_transform.OutputTransform` to apply.
output: The :class:`~flash.core.data.io.output.Output` to apply.
"""

def __init__(self, output_transform: OutputTransform, output: Output):
    """Keep hold of the output transform and the output to apply to predictions.

    Args:
        output_transform: Stored on ``self.output_transform``.
        output: Stored on ``self.output``.
    """
    super().__init__()

    self.output, self.output_transform = output, output_transform

def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
predict_step = pl_module.predict_step

@functools.wraps(predict_step)
def wrapper(*args, **kwargs):
predictions = predict_step(*args, **kwargs)
if predictions is not None:
predictions = self.output_transform(predictions)
predictions = [self.output(prediction) for prediction in predictions]
return predictions
Comment on lines +44 to +48
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, when do you think predictions would be None? Should that be counted as a failure? Or should a warning be raised that the OutputTransform and Output instances passed were not used?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there are some cases where it can be None, but I'm not sure; it may just be within our tests that it can be None. But yeah, it could be better to have an error there.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll also see if there is a possibility that predictions can be None, but for now - I guess we can merge this PR and create a small follow-up PR if required (for the error).


pl_module.predict_step = wrapper

def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Put back the ``predict_step`` that was wrapped in ``on_predict_start``.

    The wrapper installed in ``on_predict_start`` is created with ``functools.wraps``, so the original
    callable is available on its ``__wrapped__`` attribute.

    Args:
        trainer: The trainer running prediction (unused).
        pl_module: The module whose ``predict_step`` is restored.
    """
    original = getattr(pl_module.predict_step, "__wrapped__", None)
    # Only restore when a wrapped step is actually installed; otherwise leave ``predict_step`` untouched
    # instead of failing with an ``AttributeError``.
    if original is not None:
        pl_module.predict_step = original
33 changes: 10 additions & 23 deletions flash/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import functools
import inspect
import warnings
from argparse import ArgumentParser, Namespace
Expand All @@ -30,6 +28,7 @@
import flash
from flash.core.data.io.output import Output
from flash.core.data.io.output_transform import OutputTransform
from flash.core.data.io.transform_predictions import TransformPredictions
from flash.core.model import Task
from flash.core.registry import FlashRegistry

Expand Down Expand Up @@ -79,7 +78,7 @@ class Trainer(PlTrainer):
def __init__(self, *args, **kwargs):
if flash._IS_TESTING:
if torch.cuda.is_available():
kwargs["gpus"] = 1
kwargs["gpus"] = torch.cuda.device_count()
Borda marked this conversation as resolved.
Show resolved Hide resolved
Borda marked this conversation as resolved.
Show resolved Hide resolved
kwargs["max_epochs"] = 3
kwargs["limit_train_batches"] = 1.0
kwargs["limit_val_batches"] = 1.0
Expand Down Expand Up @@ -162,24 +161,6 @@ def finetune(
self._resolve_callbacks(model, strategy, train_bn=train_bn)
return super().fit(model, train_dataloader, val_dataloaders, datamodule)

@contextlib.contextmanager
def _wrap_predict_step(self, model, output_transform, output) -> None:
predict_step = model.predict_step

@functools.wraps(predict_step)
def wrapper(*args, **kwargs):
predictions = predict_step(*args, **kwargs)
if predictions is not None:
predictions = output_transform(predictions)
predictions = [output(prediction) for prediction in predictions]
return predictions

model.predict_step = wrapper
try:
yield
finally:
model.predict_step = predict_step

def predict(
self,
model: Optional[LightningModule] = None,
Expand Down Expand Up @@ -210,8 +191,14 @@ def predict(
if isinstance(output, str) and isinstance(model, Task):
output = getattr(model, "outputs", FlashRegistry("outputs")).get(output).from_task(model)

with self._wrap_predict_step(model, output_transform, output):
return super().predict(model, dataloaders, **kwargs)
old_callbacks = self.callbacks
self.callbacks = self._merge_callbacks(self.callbacks, [TransformPredictions(output_transform, output)])

result = super().predict(model, dataloaders, **kwargs)

self.callbacks = old_callbacks

return result

def _resolve_callbacks(
self,
Expand Down
2 changes: 1 addition & 1 deletion flash/image/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,6 @@ def serve(
def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
"""This function is used only for debugging usage with CI."""
if self.hparams.multi_label:
assert history[-1]["val_f1"] > 0.30, history[-1]["val_f1"]
assert history[-1]["val_f1score"] > 0.30, history[-1]["val_f1score"]
else:
assert history[-1]["val_accuracy"] > 0.85, history[-1]["val_accuracy"]
2 changes: 1 addition & 1 deletion flash/image/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,4 +198,4 @@ def serve(
@staticmethod
def _ci_benchmark_fn(history: List[Dict[str, Any]]):
"""This function is used only for debugging usage with CI."""
assert history[-1]["val_iou"] > 0.2
assert history[-1]["val_jaccardindex"] > 0.2
2 changes: 1 addition & 1 deletion flash/text/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
"""This function is used only for debugging usage with CI."""
if self.hparams.multi_label:
assert history[-1]["val_f1"] > 0.40, history[-1]["val_f1"]
assert history[-1]["val_f1score"] > 0.40, history[-1]["val_f1score"]
else:
assert history[-1]["val_accuracy"] > 0.70, history[-1]["val_accuracy"]

Expand Down
2 changes: 1 addition & 1 deletion flash/text/question_answering/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def load_data(

if flash._IS_TESTING:
# NOTE: must subset in this way to return a Dataset
hf_dataset = hf_dataset.select(range(20))
hf_dataset = [sample for sample in hf_dataset.select(range(40), keep_in_memory=True)]

return hf_dataset

Expand Down
2 changes: 1 addition & 1 deletion flash/text/seq2seq/core/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def load_data(

if flash._IS_TESTING:
# NOTE: must subset in this way to return a Dataset
hf_dataset = hf_dataset.select(range(20))
hf_dataset = [sample for sample in hf_dataset.select(range(40), keep_in_memory=True)]

return hf_dataset

Expand Down
1 change: 1 addition & 0 deletions requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pytest-flake8
flake8
pytest-doctestplus>=0.9.0
pytest-rerunfailures>=10.0
pytest-forked
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved

# install pkg
check-manifest
Expand Down
9 changes: 8 additions & 1 deletion tests/examples/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from unittest import mock

import pytest
import torch.cuda
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved

from flash.core.utilities.imports import (
_AUDIO_TESTING,
Expand All @@ -30,6 +31,7 @@
_VIDEO_TESTING,
)
from tests.examples.utils import run_test
from tests.helpers.forked import forked

root = Path(__file__).parent.parent.parent

Expand Down Expand Up @@ -82,7 +84,11 @@
marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"),
),
pytest.param(
"style_transfer.py", marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed")
"style_transfer.py",
marks=[
pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"),
pytest.mark.skipif(torch.cuda.device_count() >= 2, reason="PyStiche doesn't support DDP"),
],
),
pytest.param(
"summarization.py", marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed")
Expand Down Expand Up @@ -141,5 +147,6 @@
),
],
)
@forked
def test_example(tmpdir, file):
run_test(str(root / "flash_examples" / file))
9 changes: 6 additions & 3 deletions tests/examples/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ def call_script(
timeout: Optional[int] = 60 * 10,
) -> Tuple[int, str, str]:
with open(filepath) as original:
data = original.read()
data = original.readlines()

with open(filepath, "w") as modified:
modified.write("import pytorch_lightning as pl\npl.seed_everything(42)\n" + data)
modified.write("import pytorch_lightning as pl\npl.seed_everything(42)\n")
modified.write("if __name__ == '__main__':\n")
for line in data:
modified.write(f" {line}\n")

if args is None:
args = []
Expand All @@ -42,7 +45,7 @@ def call_script(
stderr = stderr.decode("utf-8")

with open(filepath, "w") as modified:
modified.write(data)
modified.writelines(data)

return p.returncode, stdout, stderr

Expand Down
23 changes: 23 additions & 0 deletions tests/helpers/forked.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest


def forked(callable):
    """Apply the ``pytest.mark.forked`` marker to ``callable`` unless running on Windows.

    When ``os.name`` is ``"nt"`` the callable is handed back unchanged.
    """
    # PyTest forked not available in Windows
    on_windows = os.name == "nt"
    return callable if on_windows else pytest.mark.forked(callable)