diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7a4ed7a56d0..eb784df995c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,8 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Added Learned Perceptual Image Patch Similarity (LPIPS) ([#431](https://github.com/PyTorchLightning/metrics/issues/431))
+
 - Added support for float targets in `nDCG` metric ([#437](https://github.com/PyTorchLightning/metrics/pull/437))
+
 
 ### Changed
diff --git a/docs/source/links.rst b/docs/source/links.rst
index dea08edf960..acc947c131c 100644
--- a/docs/source/links.rst
+++ b/docs/source/links.rst
@@ -38,3 +38,4 @@
 .. _MAPE implementation returns: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_percentage_error.html
 .. _mean squared logarithmic error: https://scikit-learn.org/stable/modules/model_evaluation.html#mean-squared-log-error
 .. _Mean Reciprocal Rank: https://en.wikipedia.org/wiki/Mean_reciprocal_rank
+.. _LPIPS: https://arxiv.org/abs/1801.03924
diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst
index f103a78ee57..63ea6f5c32b 100644
--- a/docs/source/references/modules.rst
+++ b/docs/source/references/modules.rst
@@ -335,6 +335,11 @@ KID
 .. autoclass:: torchmetrics.KID
     :noindex:
 
+LPIPS
+~~~~~
+
+.. autoclass:: torchmetrics.LPIPS
+    :noindex:
 
 PSNR
 ~~~~
@@ -342,7 +347,6 @@ PSNR
 .. autoclass:: torchmetrics.PSNR
     :noindex:
 
-
 SSIM
 ~~~~
diff --git a/requirements/image.txt b/requirements/image.txt
index 462520fd6b6..f2604f26706 100644
--- a/requirements/image.txt
+++ b/requirements/image.txt
@@ -1,3 +1,4 @@
 scipy
 torchvision  # this is needed to internally set TV version according installed PT
 torch-fidelity
+lpips
diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py
index 6812744d2e1..6304ebffc0a 100644
--- a/tests/helpers/testers.py
+++ b/tests/helpers/testers.py
@@ -15,7 +15,7 @@
 import pickle
 import sys
 from functools import partial
-from typing import Any, Callable, Sequence
+from typing import Any, Callable, Optional, Sequence
 
 import numpy as np
 import pytest
@@ -61,7 +61,7 @@ def _assert_allclose(pl_result: Any, sk_result: Any, atol: float = 1e-8):
     """Utility function for recursively asserting that two results are within a certain tolerance."""
     # single output compare
     if isinstance(pl_result, Tensor):
-        assert np.allclose(pl_result.cpu().numpy(), sk_result, atol=atol, equal_nan=True)
+        assert np.allclose(pl_result.detach().cpu().numpy(), sk_result, atol=atol, equal_nan=True)
     # multi output compare
     elif isinstance(pl_result, Sequence):
         for pl_res, sk_res in zip(pl_result, sk_result):
@@ -127,6 +127,9 @@ def _class_test(
         kwargs_update: Additional keyword arguments that will be passed with preds and
             target when running update on the metric.
     """
+    assert preds.shape[0] == target.shape[0]
+    num_batches = preds.shape[0]
+
     if not metric_args:
         metric_args = {}
@@ -149,7 +152,7 @@ def _class_test(
     pickled_metric = pickle.dumps(metric)
     metric = pickle.loads(pickled_metric)
 
-    for i in range(rank, NUM_BATCHES, worldsize):
+    for i in range(rank, num_batches, worldsize):
         batch_kwargs_update = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}
 
         batch_result = metric(preds[i], target[i], **batch_kwargs_update)
@@ -177,10 +180,10 @@ def _class_test(
     result = metric.compute()
     _assert_tensor(result)
 
-    total_preds = torch.cat([preds[i] for i in range(NUM_BATCHES)]).cpu()
-    total_target = torch.cat([target[i] for i in range(NUM_BATCHES)]).cpu()
+    total_preds = torch.cat([preds[i] for i in range(num_batches)]).cpu()
+    total_target = torch.cat([target[i] for i in range(num_batches)]).cpu()
     total_kwargs_update = {
-        k: torch.cat([v[i] for i in range(NUM_BATCHES)]).cpu() if isinstance(v, Tensor) else v
+        k: torch.cat([v[i] for i in range(num_batches)]).cpu() if isinstance(v, Tensor) else v
         for k, v in kwargs_update.items()
     }
     sk_result = sk_metric(total_preds, total_target, **total_kwargs_update)
@@ -213,6 +216,9 @@ def _functional_test(
         kwargs_update: Additional keyword arguments that will be passed with preds and
             target when running update on the metric.
     """
+    assert preds.shape[0] == target.shape[0]
+    num_batches = preds.shape[0]
+
     if not metric_args:
         metric_args = {}
@@ -223,7 +229,7 @@ def _functional_test(
     target = target.to(device)
     kwargs_update = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}
 
-    for i in range(NUM_BATCHES):
+    for i in range(num_batches):
         extra_kwargs = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}
         lightning_result = metric(preds[i], target[i], **extra_kwargs)
         extra_kwargs = {
@@ -238,7 +244,7 @@
 
 def _assert_half_support(
     metric_module: Metric,
-    metric_functional: Callable,
+    metric_functional: Optional[Callable],
     preds: Tensor,
     target: Tensor,
     device: str = "cpu",
@@ -263,7 +269,8 @@ def _assert_half_support(
     }
     metric_module = metric_module.to(device)
     _assert_tensor(metric_module(y_hat, y, **kwargs_update))
-    _assert_tensor(metric_functional(y_hat, y, **kwargs_update))
+    if metric_functional is not None:
+        _assert_tensor(metric_functional(y_hat, y, **kwargs_update))
 
 
 class MetricTester:
@@ -411,8 +418,8 @@ def run_precision_test_cpu(
         preds: Tensor,
         target: Tensor,
         metric_module: Metric,
-        metric_functional: Callable,
-        metric_args: dict = None,
+        metric_functional: Optional[Callable] = None,
+        metric_args: Optional[dict] = None,
         **kwargs_update,
     ):
         """Test if a metric can be used with half precision tensors on cpu
@@ -435,8 +442,8 @@ def run_precision_test_gpu(
         preds: Tensor,
         target: Tensor,
         metric_module: Metric,
-        metric_functional: Callable,
-        metric_args: dict = None,
+        metric_functional: Optional[Callable] = None,
+        metric_args: Optional[dict] = None,
         **kwargs_update,
    ):
         """Test if a metric can be used with half precision tensors on gpu
@@ -459,8 +466,8 @@ def run_differentiability_test(
         preds: Tensor,
         target: Tensor,
         metric_module: Metric,
-        metric_functional: Callable,
-        metric_args: dict = None,
+        metric_functional: Optional[Callable] = None,
+        metric_args: Optional[dict] = None,
     ):
         """Test if a metric is differentiable or not.
@@ -480,7 +487,7 @@ def run_differentiability_test(
 
         # Check if requires_grad matches is_differentiable attribute
         _assert_requires_grad(metric, out)
 
-        if metric.is_differentiable:
+        if metric.is_differentiable and metric_functional is not None:
             # check for numerical correctness
             assert torch.autograd.gradcheck(
                 partial(metric_functional, **metric_args), (preds[0].double(), target[0])
diff --git a/tests/image/test_lpips.py b/tests/image/test_lpips.py
new file mode 100644
index 00000000000..5a51bef987f
--- /dev/null
+++ b/tests/image/test_lpips.py
@@ -0,0 +1,103 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import namedtuple
+from functools import partial
+
+import pytest
+import torch
+from lpips import LPIPS as reference_LPIPS
+from torch import Tensor
+
+from tests.helpers import seed_all
+from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
+from torchmetrics.image.lpip_similarity import LPIPS
+from torchmetrics.utilities.imports import _LPIPS_AVAILABLE
+
+seed_all(42)
+
+Input = namedtuple("Input", ["img1", "img2"])
+
+_inputs = Input(
+    img1=torch.rand(int(NUM_BATCHES * 0.4), int(BATCH_SIZE / 16), 3, 100, 100),
+    img2=torch.rand(int(NUM_BATCHES * 0.4), int(BATCH_SIZE / 16), 3, 100, 100),
+)
+
+
+def _compare_fn(img1: Tensor, img2: Tensor, net_type: str, reduction: str = "mean") -> Tensor:
+    """Reference comparison function that the tm implementation is tested against."""
+    ref = reference_LPIPS(net=net_type)
+    res = ref(img1, img2).detach().cpu().numpy()
+    if reduction == "mean":
+        return res.mean()
+    return res.sum()
+
+
+@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
+@pytest.mark.parametrize("net_type", ["vgg", "alex", "squeeze"])
+class TestLPIPS(MetricTester):
+    @pytest.mark.parametrize("ddp", [True, False])
+    def test_lpips(self, net_type, ddp):
+        """test modular implementation for correctness."""
+        self.run_class_metric_test(
+            ddp=ddp,
+            preds=_inputs.img1,
+            target=_inputs.img2,
+            metric_class=LPIPS,
+            sk_metric=partial(_compare_fn, net_type=net_type),
+            dist_sync_on_step=False,
+            check_scriptable=False,
+            metric_args={"net_type": net_type},
+        )
+
+    def test_lpips_differentiability(self, net_type):
+        """test for differentiability of LPIPS metric."""
+        self.run_differentiability_test(preds=_inputs.img1, target=_inputs.img2, metric_module=LPIPS)
+
+    # LPIPS half + cpu does not work due to missing support in torch.min
+    @pytest.mark.xfail(reason="LPIPS metric does not support cpu + half precision")
+    def test_lpips_half_cpu(self, net_type):
+        """test for half + cpu support."""
+        self.run_precision_test_cpu(_inputs.img1, _inputs.img2, LPIPS)
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
+    def test_lpips_half_gpu(self, net_type):
+        """test for half + gpu support."""
+        self.run_precision_test_gpu(_inputs.img1, _inputs.img2, LPIPS)
+
+
+@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
+def test_error_on_wrong_init():
+    """Test class raises the expected errors."""
+    with pytest.raises(ValueError, match="Argument `net_type` must be one .*"):
+        LPIPS(net_type="resnet")
+
+    with pytest.raises(ValueError, match="Argument `reduction` must be one .*"):
+        LPIPS(reduction=None)
+
+
+@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
+@pytest.mark.parametrize(
+    "inp1, inp2",
+    [
+        (torch.rand(1, 1, 28, 28), torch.rand(1, 3, 28, 28)),  # wrong number of channels
+        (torch.rand(1, 3, 28, 28), torch.rand(1, 1, 28, 28)),  # wrong number of channels
+        (torch.randn(1, 3, 28, 28), torch.rand(1, 3, 28, 28)),  # non-normalized input
+        (torch.rand(1, 3, 28, 28), torch.randn(1, 3, 28, 28)),  # non-normalized input
+    ],
+)
+def test_error_on_wrong_update(inp1, inp2):
+    """test error is raised on wrong input to update method."""
+    metric = LPIPS()
+    with pytest.raises(ValueError, match="Expected both input arguments to be normalized tensors .*"):
+        metric(inp1, inp2)
diff --git a/tests/wrappers/test_bootstrapping.py b/tests/wrappers/test_bootstrapping.py
index f7e6bfb22f8..a85b04d4cb1 100644
--- a/tests/wrappers/test_bootstrapping.py
+++ b/tests/wrappers/test_bootstrapping.py
@@ -19,11 +19,14 @@
 from sklearn.metrics import precision_score, recall_score
 from torch import Tensor
 
+from tests.helpers import seed_all
 from torchmetrics.classification import Precision, Recall
 from torchmetrics.utilities import apply_to_collection
 from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_7
 from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler
 
+seed_all(42)
+
 _preds = torch.randint(10, (10, 32))
 _target = torch.randint(10, (10, 32))
 
@@ -55,10 +58,10 @@ def _sample_checker(old_samples, new_samples, op: operator, threshold: int):
 @pytest.mark.parametrize("sampling_strategy", ["poisson", "multinomial"])
 def test_bootstrap_sampler(sampling_strategy):
     """make sure that the bootstrap sampler works as intended."""
-    old_samples = torch.randn(10, 2)
+    old_samples = torch.randn(20, 2)
 
     # make sure that the new samples are only made up of old samples
-    idx = _bootstrap_sampler(10, sampling_strategy=sampling_strategy)
+    idx = _bootstrap_sampler(20, sampling_strategy=sampling_strategy)
     new_samples = old_samples[idx]
     for ns in new_samples:
         assert ns in old_samples
diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py
index 9f9dd038235..660e4e591cd 100644
--- a/torchmetrics/__init__.py
+++ b/torchmetrics/__init__.py
@@ -40,7 +40,7 @@
     StatScores,
 )
 from torchmetrics.collections import MetricCollection  # noqa: E402
-from torchmetrics.image import FID, IS, KID, PSNR, SSIM  # noqa: E402
+from torchmetrics.image import FID, IS, KID, LPIPS, PSNR, SSIM  # noqa: E402
 from torchmetrics.metric import Metric  # noqa: E402
 from torchmetrics.regression import (  # noqa: E402
     CosineSimilarity,
@@ -92,6 +92,7 @@
     "IS",
     "KID",
     "KLDivergence",
+    "LPIPS",
     "MatthewsCorrcoef",
     "MeanAbsoluteError",
     "MeanAbsolutePercentageError",
diff --git a/torchmetrics/image/__init__.py b/torchmetrics/image/__init__.py
index 6070d3e1298..8ee5d0c5107 100644
--- a/torchmetrics/image/__init__.py
+++ b/torchmetrics/image/__init__.py
@@ -14,5 +14,6 @@
 from torchmetrics.image.fid import FID  # noqa: F401
 from torchmetrics.image.inception import IS  # noqa: F401
 from torchmetrics.image.kid import KID  # noqa: F401
+from torchmetrics.image.lpip_similarity import LPIPS  # noqa: F401
 from torchmetrics.image.psnr import PSNR  # noqa: F401
 from torchmetrics.image.ssim import SSIM  # noqa: F401
diff --git a/torchmetrics/image/lpip_similarity.py b/torchmetrics/image/lpip_similarity.py
new file mode 100644
index 00000000000..52418187059
--- /dev/null
+++ b/torchmetrics/image/lpip_similarity.py
@@ -0,0 +1,159 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, List, Optional
+
+import torch
+from torch import Tensor
+
+from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _LPIPS_AVAILABLE
+
+if _LPIPS_AVAILABLE:
+    from lpips import LPIPS as Lpips_backbone
+else:
+
+    class Lpips_backbone(torch.nn.Module):  # type: ignore
+        pass
+
+
+class NoTrainLpips(Lpips_backbone):
+    def train(self, mode: bool) -> "NoTrainLpips":
+        """the network should not be able to be switched away from evaluation mode."""
+        return super().train(False)
+
+
+def _valid_img(img: Tensor) -> bool:
+    """check that input is a valid image to the network."""
+    return img.ndim == 4 and img.shape[1] == 3 and img.min() >= -1.0 and img.max() <= 1.0
+
+
+class LPIPS(Metric):
+    """The Learned Perceptual Image Patch Similarity (`LPIPS_`) is used to judge the perceptual similarity between
+    two images. LPIPS essentially computes the similarity between the activations of two image patches for some
+    pre-defined network. This measure has been shown to match human perception well. A low LPIPS score means that
+    the image patches are perceptually similar.
+
+    Both input image patches are expected to have shape `[N, 3, H, W]` and be normalized to the [-1,1]
+    range. The minimum size of `H, W` depends on the chosen backbone (see `net_type` arg).
+
+    .. note:: using this metric requires you to have the ``lpips`` package installed. Either install
+        as ``pip install torchmetrics[image]`` or ``pip install lpips``
+
+    .. note:: this metric is not scriptable when using ``torch<1.8``. Please update your pytorch installation
+        if this is an issue.
+
+    Args:
+        net_type: str indicating backbone network type to use. Choose between `'alex'`, `'vgg'` or `'squeeze'`
+        reduction: str indicating how to reduce over the batch dimension. Choose between `'sum'` or `'mean'`.
+        compute_on_step:
+            Forward only calls ``update()`` and returns ``None`` if this is set to ``False``.
+        dist_sync_on_step:
+            Synchronize metric state across processes at each ``forward()``
+            before returning the value at the step
+        process_group:
+            Specify the process group on which synchronization is called.
+            default: ``None`` (which selects the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When ``None``, DDP
+            will be used to perform the allgather
+
+    Raises:
+        ValueError:
+            If the ``lpips`` package is not installed
+        ValueError:
+            If ``net_type`` is not one of ``"vgg"``, ``"alex"`` or ``"squeeze"``
+        ValueError:
+            If ``reduction`` is not one of ``"mean"`` or ``"sum"``
+
+    Example:
+        >>> import torch
+        >>> _ = torch.manual_seed(123)
+        >>> from torchmetrics import LPIPS
+        >>> lpips = LPIPS(net_type='vgg')
+        >>> img1 = torch.rand(10, 3, 100, 100)
+        >>> img2 = torch.rand(10, 3, 100, 100)
+        >>> lpips(img1, img2)
+        tensor([0.3566], grad_fn=<DivBackward0>)
+    """
+
+    sum_scores: Tensor
+    total: Tensor
+
+    # due to the use of named tuple in the backbone the net variable cannot be scripted
+    __jit_ignored_attributes__ = ["net"]
+
+    def __init__(
+        self,
+        net_type: str = "alex",
+        reduction: str = "mean",
+        compute_on_step: bool = True,
+        dist_sync_on_step: bool = False,
+        process_group: Optional[Any] = None,
+        dist_sync_fn: Optional[Callable[[Tensor], List[Tensor]]] = None,
+    ) -> None:
+        super().__init__(
+            compute_on_step=compute_on_step,
+            dist_sync_on_step=dist_sync_on_step,
+            process_group=process_group,
+            dist_sync_fn=dist_sync_fn,
+        )
+
+        if not _LPIPS_AVAILABLE:
+            raise ValueError(
+                "LPIPS metric requires that lpips is installed."
+                " Either install as `pip install torchmetrics[image]` or `pip install lpips`"
+            )
+
+        valid_net_type = ("vgg", "alex", "squeeze")
+        if net_type not in valid_net_type:
+            raise ValueError(f"Argument `net_type` must be one of {valid_net_type}, but got {net_type}.")
+        self.net = NoTrainLpips(net=net_type, verbose=False)
+
+        valid_reduction = ("mean", "sum")
+        if reduction not in valid_reduction:
+            raise ValueError(f"Argument `reduction` must be one of {valid_reduction}, but got {reduction}")
+        self.reduction = reduction
+
+        self.add_state("sum_scores", torch.zeros(1), dist_reduce_fx="sum")
+        self.add_state("total", torch.zeros(1), dist_reduce_fx="sum")
+
+    def update(self, img1: Tensor, img2: Tensor) -> None:  # type: ignore
+        """Update internal states with lpips score.
+
+        Args:
+            img1: tensor with images of shape [N, 3, H, W]
+            img2: tensor with images of shape [N, 3, H, W]
+        """
+        if not (_valid_img(img1) and _valid_img(img2)):
+            raise ValueError(
+                "Expected both input arguments to be normalized tensors (all values in range [-1,1])"
+                f" and to have shape [N, 3, H, W], but `img1` has shape {img1.shape} with values in"
+                f" range {[img1.min(), img1.max()]} and `img2` has shape {img2.shape} with values"
+                f" in range {[img2.min(), img2.max()]}"
+            )
+
+        loss = self.net(img1, img2).squeeze()
+        self.sum_scores += loss.sum()
+        self.total += img1.shape[0]
+
+    def compute(self) -> Tensor:
+        """Compute final perceptual similarity metric."""
+        if self.reduction == "mean":
+            return self.sum_scores / self.total
+        elif self.reduction == "sum":
+            return self.sum_scores
+
+    @property
+    def is_differentiable(self) -> bool:
+        return True
diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py
index 15e65fcdc03..36e49376a4c 100644
--- a/torchmetrics/utilities/imports.py
+++ b/torchmetrics/utilities/imports.py
@@ -79,3 +79,4 @@ def _compare_version(package: str, op: Callable, version: str) -> Optional[bool]
 _BERTSCORE_AVAILABLE: bool = _module_available("bert_score")
 _SCIPY_AVAILABLE: bool = _module_available("scipy")
 _TORCH_FIDELITY_AVAILABLE: bool = _module_available("torch_fidelity")
+_LPIPS_AVAILABLE: bool = _module_available("lpips")
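
For anyone reviewing the new metric end to end, a minimal usage sketch follows (assuming the `lpips` package is installed; the backbone choice, batch shapes, seed, and two-batch loop below are illustrative only, not taken from the code above):

    # Minimal usage sketch for the LPIPS metric added in this diff.
    # Assumes `pip install lpips` (or `pip install torchmetrics[image]`);
    # all shapes and values here are illustrative.
    import torch

    from torchmetrics import LPIPS

    _ = torch.manual_seed(42)

    metric = LPIPS(net_type="squeeze", reduction="mean")

    for _ in range(2):  # state accumulates across update() calls
        # inputs must have shape [N, 3, H, W] and lie in the [-1, 1] range
        img1 = torch.rand(4, 3, 100, 100) * 2 - 1
        img2 = torch.rand(4, 3, 100, 100) * 2 - 1
        metric.update(img1, img2)

    score = metric.compute()  # mean LPIPS over all 8 image pairs

With `reduction="sum"`, `compute()` would instead return the summed score over all pairs seen since the last `reset()`.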