Fix DDP support (#1182)
Co-authored-by: Jirka Borovec <[email protected]>
ethanwharris and Borda committed Mar 1, 2022
1 parent 847d751 commit c922d3d
Showing 14 changed files with 113 additions and 76 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

+- Fixed a bug where DDP would not work with Flash tasks ([#1182](https://github.com/PyTorchLightning/lightning-flash/pull/1182))
+
## [0.7.0] - 2022-02-15

### Added
37 changes: 0 additions & 37 deletions flash/core/data/io/input.py
@@ -300,43 +300,6 @@ def predict_load_sample(self, sample: Dict[str, Any]) -> Any:
"""
return self.load_sample(sample)

def __getstate__(self):
"""Temporarily override pickle behaviour.
TODO: New DataPipeline should avoid this being pickled.
"""
state = self.__dict__.copy()
state.pop("data")
if "data_iter" in state:
state.pop("data_iter")
return state

def __setstate__(self, newstate):
"""Temporarily override pickle behaviour.
TODO: New DataPipeline should avoid this being pickled.
"""
newstate["data"] = None
self.__dict__.update(newstate)

def __copy__(self):
"""The default copy implementation seems to use ``__getstate__`` and ``__setstate__`` so we override it
here with a custom implementation to ensure that it includes the data list."""
cls = self.__class__
result = cls.__new__(cls)
result.__dict__.update(self.__dict__)
return result

def __deepcopy__(self, memo):
"""The default deepcopy implementation seems to use ``__getstate__`` and ``__setstate__`` so we override it
here with a custom implementation to ensure that it includes the data list."""
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
setattr(result, k, deepcopy(v, memo))
return result

def __bool__(self):
"""If ``self.data`` is ``None`` then the ``InputBase`` is considered falsey.
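The overrides deleted above stripped the `data` attribute whenever an `InputBase` was pickled, which is effectively what spawn-based DDP does when it ships inputs to worker processes: each worker received an input whose `data` was `None` and had nothing to load. A minimal sketch of that failure mode (`InputLike` is a hypothetical stand-in, not Flash's `InputBase`):

```python
import pickle


class InputLike:
    """Hypothetical stand-in for an input container that drops state on pickle."""

    def __init__(self, data):
        self.data = data

    def __getstate__(self):
        # Mimics the removed override: the data list never survives pickling.
        state = self.__dict__.copy()
        state.pop("data")
        return state

    def __setstate__(self, newstate):
        newstate["data"] = None
        self.__dict__.update(newstate)


worker_copy = pickle.loads(pickle.dumps(InputLike([1, 2, 3])))  # what a spawned worker receives
assert worker_copy.data is None  # the worker has no samples left to load
```

Removing the overrides lets the default pickle behaviour carry the data list into each process.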
53 changes: 53 additions & 0 deletions flash/core/data/io/transform_predictions.py
@@ -0,0 +1,53 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools

import pytorch_lightning as pl
from pytorch_lightning import Callback

from flash.core.data.io.output import Output
from flash.core.data.io.output_transform import OutputTransform


class TransformPredictions(Callback):
    """``TransformPredictions`` is a :class:`~pytorch_lightning.callbacks.base.Callback` which can be used to apply an
    :class:`~flash.core.data.io.output_transform.OutputTransform` and an :class:`~flash.core.data.io.output.Output` to
    model predictions.

    Args:
        output_transform: The :class:`~flash.core.data.io.output_transform.OutputTransform` to apply.
        output: The :class:`~flash.core.data.io.output.Output` to apply.
    """

    def __init__(self, output_transform: OutputTransform, output: Output):
        super().__init__()

        self.output_transform = output_transform
        self.output = output

    def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        predict_step = pl_module.predict_step

        @functools.wraps(predict_step)
        def wrapper(*args, **kwargs):
            predictions = predict_step(*args, **kwargs)
            if predictions is not None:
                predictions = self.output_transform(predictions)
                predictions = [self.output(prediction) for prediction in predictions]
            return predictions

        pl_module.predict_step = wrapper

    def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        pl_module.predict_step = pl_module.predict_step.__wrapped__
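A minimal usage sketch for the new callback, assuming the base `OutputTransform` and `Output` classes can be instantiated directly as no-op defaults (Flash's own `Trainer.predict`, updated below, attaches the callback automatically, so wiring it by hand is only needed with a plain `pytorch_lightning.Trainer`):

```python
import pytorch_lightning as pl

from flash.core.data.io.output import Output
from flash.core.data.io.output_transform import OutputTransform
from flash.core.data.io.transform_predictions import TransformPredictions

# Because the wrapping happens in on_predict_start, it is applied inside each
# DDP worker process, not just in the process that called predict().
trainer = pl.Trainer(callbacks=[TransformPredictions(OutputTransform(), Output())])
```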
33 changes: 10 additions & 23 deletions flash/core/trainer.py
@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import contextlib
-import functools
import inspect
import warnings
from argparse import ArgumentParser, Namespace
@@ -30,6 +28,7 @@
import flash
from flash.core.data.io.output import Output
from flash.core.data.io.output_transform import OutputTransform
+from flash.core.data.io.transform_predictions import TransformPredictions
from flash.core.model import Task
from flash.core.registry import FlashRegistry

@@ -79,7 +78,7 @@ class Trainer(PlTrainer):
    def __init__(self, *args, **kwargs):
        if flash._IS_TESTING:
            if torch.cuda.is_available():
-                kwargs["gpus"] = 1
+                kwargs["gpus"] = -1
            kwargs["max_epochs"] = 3
            kwargs["limit_train_batches"] = 1.0
            kwargs["limit_val_batches"] = 1.0
@@ -162,24 +161,6 @@ def finetune(
        self._resolve_callbacks(model, strategy, train_bn=train_bn)
        return super().fit(model, train_dataloader, val_dataloaders, datamodule)

-    @contextlib.contextmanager
-    def _wrap_predict_step(self, model, output_transform, output) -> None:
-        predict_step = model.predict_step
-
-        @functools.wraps(predict_step)
-        def wrapper(*args, **kwargs):
-            predictions = predict_step(*args, **kwargs)
-            if predictions is not None:
-                predictions = output_transform(predictions)
-                predictions = [output(prediction) for prediction in predictions]
-            return predictions
-
-        model.predict_step = wrapper
-        try:
-            yield
-        finally:
-            model.predict_step = predict_step

    def predict(
        self,
        model: Optional[LightningModule] = None,
@@ -210,8 +191,14 @@ def predict(
        if isinstance(output, str) and isinstance(model, Task):
            output = getattr(model, "outputs", FlashRegistry("outputs")).get(output).from_task(model)

-        with self._wrap_predict_step(model, output_transform, output):
-            return super().predict(model, dataloaders, **kwargs)
+        old_callbacks = self.callbacks
+        self.callbacks = self._merge_callbacks(self.callbacks, [TransformPredictions(output_transform, output)])
+
+        result = super().predict(model, dataloaders, **kwargs)
+
+        self.callbacks = old_callbacks
+
+        return result

    def _resolve_callbacks(
        self,
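The removed `_wrap_predict_step` context manager patched `predict_step` only in the process that called `predict`, so spawned DDP workers ran the unwrapped method; a callback travels with the `Trainer` into each worker, where `on_predict_start` re-applies the wrapping. One side note on the committed code: if `super().predict` raised, the merged callbacks would not be restored. A hedged sketch of a `try`/`finally` variant (a hypothetical helper, not part of this commit):

```python
from contextlib import contextmanager


@contextmanager
def temporary_callbacks(trainer, extra):
    """Hypothetical helper: temporarily extend trainer.callbacks, restoring them even on error."""
    old_callbacks = trainer.callbacks
    trainer.callbacks = old_callbacks + list(extra)
    try:
        yield
    finally:
        trainer.callbacks = old_callbacks
```

With such a helper, the body of `predict` would shrink to a `with temporary_callbacks(self, [TransformPredictions(output_transform, output)]):` block around the `super().predict(...)` call.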
2 changes: 1 addition & 1 deletion flash/image/classification/model.py
@@ -181,6 +181,6 @@ def serve(
    def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
        """This function is used only for debugging usage with CI."""
        if self.hparams.multi_label:
-            assert history[-1]["val_f1"] > 0.30, history[-1]["val_f1"]
+            assert history[-1]["val_f1score"] > 0.30, history[-1]["val_f1score"]
        else:
            assert history[-1]["val_accuracy"] > 0.85, history[-1]["val_accuracy"]
2 changes: 1 addition & 1 deletion flash/image/segmentation/model.py
@@ -198,4 +198,4 @@ def serve(
    @staticmethod
    def _ci_benchmark_fn(history: List[Dict[str, Any]]):
        """This function is used only for debugging usage with CI."""
-        assert history[-1]["val_iou"] > 0.2
+        assert history[-1]["val_jaccardindex"] > 0.2
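The `val_f1` → `val_f1score` and `val_iou` → `val_jaccardindex` renames track TorchMetrics' class renames (`F1` → `F1Score`, `IoU` → `JaccardIndex`): the logged key is the stage prefix plus the lowercased class name of the metric, so renaming the class renames the key. A sketch of that naming convention (`logged_key` is a hypothetical mirror of the behaviour, not a Flash API):

```python
class F1Score:
    """Stand-in for torchmetrics.F1Score; only the class name matters here."""


def logged_key(metric, stage: str = "val") -> str:
    # Hypothetical mirror of the convention: stage prefix + lowercased class name.
    return f"{stage}_{type(metric).__name__.lower()}"


assert logged_key(F1Score()) == "val_f1score"
```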
2 changes: 1 addition & 1 deletion flash/text/classification/model.py
@@ -127,7 +127,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
    def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
        """This function is used only for debugging usage with CI."""
        if self.hparams.multi_label:
-            assert history[-1]["val_f1"] > 0.40, history[-1]["val_f1"]
+            assert history[-1]["val_f1score"] > 0.40, history[-1]["val_f1score"]
        else:
            assert history[-1]["val_accuracy"] > 0.70, history[-1]["val_accuracy"]

2 changes: 1 addition & 1 deletion flash/text/question_answering/input.py
@@ -77,7 +77,7 @@ def load_data(

        if flash._IS_TESTING:
            # NOTE: must subset in this way to return a Dataset
-            hf_dataset = hf_dataset.select(range(20))
+            hf_dataset = [sample for sample in hf_dataset.select(range(40), keep_in_memory=True)]

        return hf_dataset

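`Dataset.select` returns another Arrow-backed `datasets.Dataset` (which is what the old comment refers to); materialising the subset as a plain list of dict samples, read into memory with `keep_in_memory=True`, presumably sidesteps pickling the Arrow-backed object into spawned DDP workers. A sketch of the difference, assuming the Hugging Face `datasets` package:

```python
from datasets import Dataset

hf_dataset = Dataset.from_dict({"question": ["q0", "q1"], "answer": ["a0", "a1"]})

# Before: still an Arrow-backed Dataset object.
subset = hf_dataset.select(range(2))

# After: plain dicts, which pickle cleanly into spawned worker processes.
samples = [sample for sample in hf_dataset.select(range(2), keep_in_memory=True)]
assert samples[0] == {"question": "q0", "answer": "a0"}
```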
12 changes: 5 additions & 7 deletions flash/text/question_answering/model.py
@@ -291,13 +291,11 @@ def common_step(self, prefix: str, batch: Any) -> torch.Tensor:
        self.log_dict(result, on_step=False, on_epoch=True, prog_bar=False)

    def compute_metrics(self, generated_tokens, batch):
-        for example in batch:
-            predicted_answer = generated_tokens[example["example_id"]]
-            target_answer = example["answer"]["text"][0] if len(example["answer"]["text"]) > 0 else ""
-            self.rouge.update(predicted_answer, target_answer)
-
-        result = self.rouge.compute()
-        return result
+        predicted_answers = [generated_tokens[example["example_id"]] for example in batch]
+        target_answers = [
+            example["answer"]["text"][0] if len(example["answer"]["text"]) > 0 else "" for example in batch
+        ]
+        return self.rouge(predicted_answers, target_answers)

    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
        self.common_step("val", batch)
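Calling the metric object once over the whole batch performs update and compute in a single step, rather than accumulating per-sample state with `update` and reading it back with `compute`, which is easier to keep consistent across DDP ranks. A standalone sketch of the batched call, assuming `torchmetrics` (with its `nltk` dependency for ROUGE) is installed:

```python
from torchmetrics.text.rouge import ROUGEScore

rouge = ROUGEScore()

predicted_answers = ["the cat sat on the mat", "flash supports ddp"]
target_answers = ["the cat sat on the mat", "flash now supports ddp"]

# One functional call updates and computes together over the whole batch.
result = rouge(predicted_answers, target_answers)
print(result["rouge1_fmeasure"])  # one of several rougeN_* keys in the result dict
```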
2 changes: 1 addition & 1 deletion flash/text/seq2seq/core/input.py
@@ -44,7 +44,7 @@ def load_data(

        if flash._IS_TESTING:
            # NOTE: must subset in this way to return a Dataset
-            hf_dataset = hf_dataset.select(range(20))
+            hf_dataset = [sample for sample in hf_dataset.select(range(40), keep_in_memory=True)]

        return hf_dataset

1 change: 1 addition & 0 deletions requirements/test.txt
@@ -5,6 +5,7 @@ pytest-flake8
flake8
pytest-doctestplus>=0.9.0
pytest-rerunfailures>=10.0
+pytest-forked

# install pkg
check-manifest
9 changes: 8 additions & 1 deletion tests/examples/test_scripts.py
@@ -16,6 +16,7 @@
from unittest import mock

import pytest
+import torch

from flash.core.utilities.imports import (
    _AUDIO_TESTING,
@@ -30,6 +31,7 @@
    _VIDEO_TESTING,
)
from tests.examples.utils import run_test
+from tests.helpers.forked import forked

root = Path(__file__).parent.parent.parent

@@ -82,7 +84,11 @@
            marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"),
        ),
        pytest.param(
-            "style_transfer.py", marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed")
+            "style_transfer.py",
+            marks=[
+                pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"),
+                pytest.mark.skipif(torch.cuda.device_count() >= 2, reason="PyStiche doesn't support DDP"),
+            ],
        ),
        pytest.param(
            "summarization.py", marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed")
@@ -141,5 +147,6 @@
        ),
    ],
)
+@forked
def test_example(tmpdir, file):
    run_test(str(root / "flash_examples" / file))
9 changes: 6 additions & 3 deletions tests/examples/utils.py
@@ -22,10 +22,13 @@ def call_script(
    timeout: Optional[int] = 60 * 10,
) -> Tuple[int, str, str]:
    with open(filepath) as original:
-        data = original.read()
+        data = original.readlines()

    with open(filepath, "w") as modified:
-        modified.write("import pytorch_lightning as pl\npl.seed_everything(42)\n" + data)
+        modified.write("import pytorch_lightning as pl\npl.seed_everything(42)\n")
+        modified.write("if __name__ == '__main__':\n")
+        for line in data:
+            modified.write(f"    {line}\n")

    if args is None:
        args = []
@@ -42,7 +45,7 @@
    stderr = stderr.decode("utf-8")

    with open(filepath, "w") as modified:
-        modified.write(data)
+        modified.writelines(data)

    return p.returncode, stdout, stderr

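Wrapping each example's body in `if __name__ == '__main__':` is what makes the scripts safe under spawn-based DDP: every spawned worker re-imports the main module, so unguarded top-level code that launches training would run again in every worker and spawn recursively. A minimal sketch of the guarded pattern, using `torch.multiprocessing` directly:

```python
import torch.multiprocessing as mp


def worker(rank: int) -> None:
    print(f"worker {rank} started")


# Without this guard, each spawned child would re-execute the module's top
# level on import and call mp.spawn again, multiplying processes instead of
# just running two workers.
if __name__ == "__main__":
    mp.spawn(worker, nprocs=2)
```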
23 changes: 23 additions & 0 deletions tests/helpers/forked.py
@@ -0,0 +1,23 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest


def forked(callable):
    # pytest-forked is not available on Windows, so fall back to running the test in-process there
    if os.name == "nt":
        return callable
    return pytest.mark.forked(callable)
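A usage sketch for the helper: with the `pytest-forked` plugin from `requirements/test.txt` installed, a decorated test runs in a forked child process, so CUDA and DDP state cannot leak between tests; on Windows, where `fork` is unavailable, the decorator is a no-op and the test runs in-process:

```python
from tests.helpers.forked import forked


@forked
def test_runs_isolated():
    # Executes in a forked child process when pytest-forked is available.
    assert 1 + 1 == 2
```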
