diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3c529d56a3..b43905ced2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed a bug where DDP would not work with Flash tasks ([#1182](https://github.com/PyTorchLightning/lightning-flash/pull/1182))
+
 ## [0.7.0] - 2022-02-15
 
 ### Added
diff --git a/flash/core/data/io/input.py b/flash/core/data/io/input.py
index 04dba6e7f8..3fb6468855 100644
--- a/flash/core/data/io/input.py
+++ b/flash/core/data/io/input.py
@@ -300,43 +300,6 @@ def predict_load_sample(self, sample: Dict[str, Any]) -> Any:
         """
         return self.load_sample(sample)
 
-    def __getstate__(self):
-        """Temporarily override pickle behaviour.
-
-        TODO: New DataPipeline should avoid this being pickled.
-        """
-        state = self.__dict__.copy()
-        state.pop("data")
-        if "data_iter" in state:
-            state.pop("data_iter")
-        return state
-
-    def __setstate__(self, newstate):
-        """Temporarily override pickle behaviour.
-
-        TODO: New DataPipeline should avoid this being pickled.
-        """
-        newstate["data"] = None
-        self.__dict__.update(newstate)
-
-    def __copy__(self):
-        """The default copy implementation seems to use ``__getstate__`` and ``__setstate__`` so we override it
-        here with a custom implementation to ensure that it includes the data list."""
-        cls = self.__class__
-        result = cls.__new__(cls)
-        result.__dict__.update(self.__dict__)
-        return result
-
-    def __deepcopy__(self, memo):
-        """The default deepcopy implementation seems to use ``__getstate__`` and ``__setstate__`` so we override it
-        here with a custom implementation to ensure that it includes the data list."""
-        cls = self.__class__
-        result = cls.__new__(cls)
-        memo[id(self)] = result
-        for k, v in self.__dict__.items():
-            setattr(result, k, deepcopy(v, memo))
-        return result
-
     def __bool__(self):
         """If ``self.data`` is ``None`` then the ``InputBase`` is considered falsey.
diff --git a/flash/core/data/io/transform_predictions.py b/flash/core/data/io/transform_predictions.py
new file mode 100644
index 0000000000..a486c1659e
--- /dev/null
+++ b/flash/core/data/io/transform_predictions.py
@@ -0,0 +1,53 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import functools
+
+import pytorch_lightning as pl
+from pytorch_lightning import Callback
+
+from flash.core.data.io.output import Output
+from flash.core.data.io.output_transform import OutputTransform
+
+
+class TransformPredictions(Callback):
+    """``TransformPredictions`` is a :class:`~pytorch_lightning.callbacks.base.Callback` which can be used to apply an
+    :class:`~flash.core.data.io.output_transform.OutputTransform` and an :class:`~flash.core.data.io.output.Output` to
+    model predictions.
+
+    Args:
+        output_transform: The :class:`~flash.core.data.io.output_transform.OutputTransform` to apply.
+        output: The :class:`~flash.core.data.io.output.Output` to apply.
+    """
+
+    def __init__(self, output_transform: OutputTransform, output: Output):
+        super().__init__()
+
+        self.output_transform = output_transform
+        self.output = output
+
+    def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        predict_step = pl_module.predict_step
+
+        @functools.wraps(predict_step)
+        def wrapper(*args, **kwargs):
+            predictions = predict_step(*args, **kwargs)
+            if predictions is not None:
+                predictions = self.output_transform(predictions)
+                predictions = [self.output(prediction) for prediction in predictions]
+            return predictions
+
+        pl_module.predict_step = wrapper
+
+    def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        pl_module.predict_step = pl_module.predict_step.__wrapped__
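
The new callback above takes over the ``predict_step`` wrapping that the Flash ``Trainer`` previously did inline (see the removed ``_wrap_predict_step`` below), applying it when prediction starts and undoing it when prediction ends, which plays better with DDP's process handling. A minimal sketch of how it composes with a plain PyTorch Lightning ``Trainer``; the identity base classes ``OutputTransform`` and ``Output`` are used here only as placeholders (a real task supplies its own subclasses), and ``model`` / ``predict_loader`` stand in for any ``LightningModule`` and ``DataLoader``:

    import pytorch_lightning as pl

    from flash.core.data.io.output import Output
    from flash.core.data.io.output_transform import OutputTransform
    from flash.core.data.io.transform_predictions import TransformPredictions

    # Wraps ``predict_step`` for the duration of ``trainer.predict`` so that
    # every prediction passes through the OutputTransform and then the Output.
    callback = TransformPredictions(OutputTransform(), Output())

    trainer = pl.Trainer(callbacks=[callback])
    # predictions = trainer.predict(model, dataloaders=predict_loader)
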
+ """ + + def __init__(self, output_transform: OutputTransform, output: Output): + super().__init__() + + self.output_transform = output_transform + self.output = output + + def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + predict_step = pl_module.predict_step + + @functools.wraps(predict_step) + def wrapper(*args, **kwargs): + predictions = predict_step(*args, **kwargs) + if predictions is not None: + predictions = self.output_transform(predictions) + predictions = [self.output(prediction) for prediction in predictions] + return predictions + + pl_module.predict_step = wrapper + + def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + pl_module.predict_step = pl_module.predict_step.__wrapped__ diff --git a/flash/core/trainer.py b/flash/core/trainer.py index 3245856f4a..82339673c7 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import contextlib -import functools import inspect import warnings from argparse import ArgumentParser, Namespace @@ -30,6 +28,7 @@ import flash from flash.core.data.io.output import Output from flash.core.data.io.output_transform import OutputTransform +from flash.core.data.io.transform_predictions import TransformPredictions from flash.core.model import Task from flash.core.registry import FlashRegistry @@ -79,7 +78,7 @@ class Trainer(PlTrainer): def __init__(self, *args, **kwargs): if flash._IS_TESTING: if torch.cuda.is_available(): - kwargs["gpus"] = 1 + kwargs["gpus"] = -1 kwargs["max_epochs"] = 3 kwargs["limit_train_batches"] = 1.0 kwargs["limit_val_batches"] = 1.0 @@ -162,24 +161,6 @@ def finetune( self._resolve_callbacks(model, strategy, train_bn=train_bn) return super().fit(model, train_dataloader, val_dataloaders, datamodule) - @contextlib.contextmanager - def _wrap_predict_step(self, model, output_transform, output) -> None: - predict_step = model.predict_step - - @functools.wraps(predict_step) - def wrapper(*args, **kwargs): - predictions = predict_step(*args, **kwargs) - if predictions is not None: - predictions = output_transform(predictions) - predictions = [output(prediction) for prediction in predictions] - return predictions - - model.predict_step = wrapper - try: - yield - finally: - model.predict_step = predict_step - def predict( self, model: Optional[LightningModule] = None, @@ -210,8 +191,14 @@ def predict( if isinstance(output, str) and isinstance(model, Task): output = getattr(model, "outputs", FlashRegistry("outputs")).get(output).from_task(model) - with self._wrap_predict_step(model, output_transform, output): - return super().predict(model, dataloaders, **kwargs) + old_callbacks = self.callbacks + self.callbacks = self._merge_callbacks(self.callbacks, [TransformPredictions(output_transform, output)]) + + result = super().predict(model, dataloaders, **kwargs) + + self.callbacks = old_callbacks + + return result def _resolve_callbacks( self, diff --git a/flash/image/classification/model.py b/flash/image/classification/model.py index a56700b842..0b7261b53c 100644 --- a/flash/image/classification/model.py +++ b/flash/image/classification/model.py @@ -181,6 +181,6 @@ def serve( def _ci_benchmark_fn(self, history: List[Dict[str, Any]]): """This function is used only for debugging usage with CI.""" if self.hparams.multi_label: - assert history[-1]["val_f1"] 
> 0.30, history[-1]["val_f1"] + assert history[-1]["val_f1score"] > 0.30, history[-1]["val_f1score"] else: assert history[-1]["val_accuracy"] > 0.85, history[-1]["val_accuracy"] diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index bb8cc9d1e3..883c999300 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -198,4 +198,4 @@ def serve( @staticmethod def _ci_benchmark_fn(history: List[Dict[str, Any]]): """This function is used only for debugging usage with CI.""" - assert history[-1]["val_iou"] > 0.2 + assert history[-1]["val_jaccardindex"] > 0.2 diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index 2448c99e91..6cc29f6778 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -127,7 +127,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A def _ci_benchmark_fn(self, history: List[Dict[str, Any]]): """This function is used only for debugging usage with CI.""" if self.hparams.multi_label: - assert history[-1]["val_f1"] > 0.40, history[-1]["val_f1"] + assert history[-1]["val_f1score"] > 0.40, history[-1]["val_f1score"] else: assert history[-1]["val_accuracy"] > 0.70, history[-1]["val_accuracy"] diff --git a/flash/text/question_answering/input.py b/flash/text/question_answering/input.py index 70c1dd7239..541c451e31 100644 --- a/flash/text/question_answering/input.py +++ b/flash/text/question_answering/input.py @@ -77,7 +77,7 @@ def load_data( if flash._IS_TESTING: # NOTE: must subset in this way to return a Dataset - hf_dataset = hf_dataset.select(range(20)) + hf_dataset = [sample for sample in hf_dataset.select(range(40), keep_in_memory=True)] return hf_dataset diff --git a/flash/text/question_answering/model.py b/flash/text/question_answering/model.py index 0b1b529568..9002b56f23 100644 --- a/flash/text/question_answering/model.py +++ b/flash/text/question_answering/model.py @@ -291,13 +291,11 @@ def common_step(self, prefix: str, batch: Any) -> torch.Tensor: self.log_dict(result, on_step=False, on_epoch=True, prog_bar=False) def compute_metrics(self, generated_tokens, batch): - for example in batch: - predicted_answer = generated_tokens[example["example_id"]] - target_answer = example["answer"]["text"][0] if len(example["answer"]["text"]) > 0 else "" - self.rouge.update(predicted_answer, target_answer) - - result = self.rouge.compute() - return result + predicted_answers = [generated_tokens[example["example_id"]] for example in batch] + target_answers = [ + example["answer"]["text"][0] if len(example["answer"]["text"]) > 0 else "" for example in batch + ] + return self.rouge(predicted_answers, target_answers) def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0): self.common_step("val", batch) diff --git a/flash/text/seq2seq/core/input.py b/flash/text/seq2seq/core/input.py index 5bdc775558..0f935d49aa 100644 --- a/flash/text/seq2seq/core/input.py +++ b/flash/text/seq2seq/core/input.py @@ -44,7 +44,7 @@ def load_data( if flash._IS_TESTING: # NOTE: must subset in this way to return a Dataset - hf_dataset = hf_dataset.select(range(20)) + hf_dataset = [sample for sample in hf_dataset.select(range(40), keep_in_memory=True)] return hf_dataset diff --git a/requirements/test.txt b/requirements/test.txt index faabbef85f..8b5899f7d3 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -5,6 +5,7 @@ pytest-flake8 flake8 pytest-doctestplus>=0.9.0 pytest-rerunfailures>=10.0 +pytest-forked # 
diff --git a/flash/text/seq2seq/core/input.py b/flash/text/seq2seq/core/input.py
index 5bdc775558..0f935d49aa 100644
--- a/flash/text/seq2seq/core/input.py
+++ b/flash/text/seq2seq/core/input.py
@@ -44,7 +44,7 @@ def load_data(
 
         if flash._IS_TESTING:
             # NOTE: must subset in this way to return a Dataset
-            hf_dataset = hf_dataset.select(range(20))
+            hf_dataset = [sample for sample in hf_dataset.select(range(40), keep_in_memory=True)]
 
         return hf_dataset
 
diff --git a/requirements/test.txt b/requirements/test.txt
index faabbef85f..8b5899f7d3 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -5,6 +5,7 @@
 pytest-flake8
 flake8
 pytest-doctestplus>=0.9.0
 pytest-rerunfailures>=10.0
+pytest-forked
 # install pkg
 check-manifest
diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py
index 3c43e2daa0..f57746c65c 100644
--- a/tests/examples/test_scripts.py
+++ b/tests/examples/test_scripts.py
@@ -16,6 +16,7 @@
 from unittest import mock
 
 import pytest
+import torch
 
 from flash.core.utilities.imports import (
     _AUDIO_TESTING,
@@ -30,6 +31,7 @@
     _VIDEO_TESTING,
 )
 from tests.examples.utils import run_test
+from tests.helpers.forked import forked
 
 root = Path(__file__).parent.parent.parent
 
@@ -82,7 +84,11 @@
             marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"),
         ),
         pytest.param(
-            "style_transfer.py", marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed")
+            "style_transfer.py",
+            marks=[
+                pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"),
+                pytest.mark.skipif(torch.cuda.device_count() >= 2, reason="PyStiche doesn't support DDP"),
+            ],
         ),
         pytest.param(
             "summarization.py", marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed")
@@ -141,5 +147,6 @@
         ),
     ],
 )
+@forked
 def test_example(tmpdir, file):
     run_test(str(root / "flash_examples" / file))
diff --git a/tests/examples/utils.py b/tests/examples/utils.py
index 6a8ef4dbb3..cb923db1ce 100644
--- a/tests/examples/utils.py
+++ b/tests/examples/utils.py
@@ -22,10 +22,13 @@ def call_script(
     timeout: Optional[int] = 60 * 10,
 ) -> Tuple[int, str, str]:
     with open(filepath) as original:
-        data = original.read()
+        data = original.readlines()
 
     with open(filepath, "w") as modified:
-        modified.write("import pytorch_lightning as pl\npl.seed_everything(42)\n" + data)
+        modified.write("import pytorch_lightning as pl\npl.seed_everything(42)\n")
+        modified.write("if __name__ == '__main__':\n")
+        for line in data:
+            modified.write(f"    {line}")
 
     if args is None:
         args = []
@@ -42,7 +45,7 @@
     stderr = stderr.decode("utf-8")
 
     with open(filepath, "w") as modified:
-        modified.write(data)
+        modified.writelines(data)
 
     return p.returncode, stdout, stderr
diff --git a/tests/helpers/forked.py b/tests/helpers/forked.py
new file mode 100644
index 0000000000..22680f3291
--- /dev/null
+++ b/tests/helpers/forked.py
@@ -0,0 +1,23 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import pytest
+
+
+def forked(fn):
+    # pytest-forked is not available on Windows, so run in-process there
+    if os.name == "nt":
+        return fn
+    return pytest.mark.forked(fn)
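
The helper above is used as a plain decorator. A hypothetical test (the name and body are illustrative only, not part of this change):

    import torch

    from tests.helpers.forked import forked


    @forked
    def test_cuda_state_is_isolated():
        # On POSIX systems pytest-forked runs this body in a forked child
        # process, so CUDA / DDP state initialised here cannot leak into
        # later tests; on Windows the decorator is a no-op.
        assert torch.cuda.device_count() >= 0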