add LPIPS #431

Merged Aug 18, 2021 · 44 commits (diff shown from 30 commits)

Commits

1d0cdd8 grouping in Chlog (Borda, Aug 5, 2021)
3c21693 implemention (SkafteNicki, Aug 5, 2021)
57970ed init files (SkafteNicki, Aug 5, 2021)
9089d35 requirements (SkafteNicki, Aug 5, 2021)
04a79f0 change to optional testing (SkafteNicki, Aug 5, 2021)
baafb9e update (SkafteNicki, Aug 5, 2021)
8095fe6 working test (SkafteNicki, Aug 6, 2021)
1522fc1 Merge branch 'master' into lpips (SkafteNicki, Aug 6, 2021)
2b1ffe9 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 6, 2021)
be4d7ba docs (SkafteNicki, Aug 6, 2021)
7994d86 Update tests/image/test_lpips.py (SkafteNicki, Aug 8, 2021)
5660642 fix suggestions (SkafteNicki, Aug 8, 2021)
15f8c5b [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 8, 2021)
0334ebb Merge branch 'master' into lpips (SkafteNicki, Aug 8, 2021)
f5ee662 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 8, 2021)
87cd02d fix (SkafteNicki, Aug 8, 2021)
dc663de [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 8, 2021)
c3596a2 lower cpu load (SkafteNicki, Aug 8, 2021)
0fa7fd2 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 8, 2021)
a1f560e fix (SkafteNicki, Aug 8, 2021)
2451973 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 8, 2021)
b1dd373 Merge branch 'master' into lpips (SkafteNicki, Aug 9, 2021)
79bd451 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 9, 2021)
ff5f39f fix docs (SkafteNicki, Aug 9, 2021)
0c4d76a [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 9, 2021)
a4d3578 Merge branch 'master' into lpips (Borda, Aug 9, 2021)
0493b25 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 9, 2021)
d095ff2 Merge branch 'master' into lpips (Borda, Aug 15, 2021)
aaef0d2 fix docs (SkafteNicki, Aug 16, 2021)
abbdd92 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 16, 2021)
36c13c9 Merge branch 'master' into lpips (Borda, Aug 16, 2021)
be4828a skip testing (SkafteNicki, Aug 17, 2021)
cf30501 Merge branch 'lpips' of https://github.com/PyTorchLightning/metrics i… (SkafteNicki, Aug 17, 2021)
86b92fb Merge branch 'master' into lpips (SkafteNicki, Aug 17, 2021)
05c4685 Merge branch 'master' into lpips (mergify[bot], Aug 17, 2021)
d9a40e1 Merge branch 'master' into lpips (mergify[bot], Aug 17, 2021)
0e9bd77 Merge branch 'master' into lpips (mergify[bot], Aug 17, 2021)
8cd29f4 Merge branch 'master' into lpips (mergify[bot], Aug 17, 2021)
28cb524 add seed (SkafteNicki, Aug 17, 2021)
af225e9 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 17, 2021)
50173fb fix (SkafteNicki, Aug 17, 2021)
77a4f6b Merge branch 'lpips' of https://github.com/PyTorchLightning/metrics i… (SkafteNicki, Aug 17, 2021)
8e1ea36 diable scripting (SkafteNicki, Aug 18, 2021)
f2804d5 pep8 (SkafteNicki, Aug 18, 2021)
CHANGELOG.md (2 additions, 0 deletions)

@@ -22,6 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

### Fixed

+- Learned Perceptual Image Patch Similarity (LPIPS) ([#431](https://github.com/PyTorchLightning/metrics/issues/431))
+

## [0.5.0] - 2021-08-09
docs/source/links.rst (1 addition, 0 deletions)

@@ -38,3 +38,4 @@
.. _MAPE implementation returns: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_percentage_error.html
.. _mean squared logarithmic error: https://scikit-learn.org/stable/modules/model_evaluation.html#mean-squared-log-error
.. _Mean Reciprocal Rank: https://en.wikipedia.org/wiki/Mean_reciprocal_rank
+.. _LPIPS: https://arxiv.org/abs/1801.03924
docs/source/references/modules.rst (5 additions, 1 deletion)

@@ -335,14 +335,18 @@ KID
.. autoclass:: torchmetrics.KID
    :noindex:

+LPIPS
+~~~~~
+
+.. autoclass:: torchmetrics.LPIPS
+    :noindex:

PSNR
~~~~

.. autoclass:: torchmetrics.PSNR
    :noindex:


SSIM
~~~~

requirements/image.txt (1 addition, 0 deletions)

@@ -1,3 +1,4 @@
scipy
torchvision # this is needed to internally set TV version according installed PT
torch-fidelity
+lpips
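
For orientation before the code diffs below, here is a minimal usage sketch of the metric this PR adds. It is assembled from the new tests rather than from final documentation, so treat the details as illustrative: the import path, the `net_type` and `reduction` arguments, and the normalized-input requirement all come from `tests/image/test_lpips.py`; the shapes and values are made up.

```python
import torch

from torchmetrics.image.lpip_similarity import LPIPS  # also re-exported as torchmetrics.LPIPS

# Sketch based on the tests in this PR; shapes are illustrative.
lpips = LPIPS(net_type="squeeze")  # one of "vgg", "alex", "squeeze"
img1 = torch.rand(4, 3, 100, 100)  # inputs must be normalized; unbounded values are rejected
img2 = torch.rand(4, 3, 100, 100)
score = lpips(img1, img2)  # batch of distances reduced to a scalar; "mean" appears to be the default
print(score)
```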
tests/helpers/testers.py (23 additions, 16 deletions)

@@ -15,7 +15,7 @@
import pickle
import sys
from functools import partial
-from typing import Any, Callable, Sequence
+from typing import Any, Callable, Optional, Sequence

import numpy as np
import pytest
@@ -61,7 +61,7 @@ def _assert_allclose(pl_result: Any, sk_result: Any, atol: float = 1e-8):
    """Utility function for recursively asserting that two results are within a certain tolerance."""
    # single output compare
    if isinstance(pl_result, Tensor):
-        assert np.allclose(pl_result.cpu().numpy(), sk_result, atol=atol, equal_nan=True)
+        assert np.allclose(pl_result.detach().cpu().numpy(), sk_result, atol=atol, equal_nan=True)
    # multi output compare
    elif isinstance(pl_result, Sequence):
        for pl_res, sk_res in zip(pl_result, sk_result):
@@ -127,6 +127,9 @@ def _class_test(
        kwargs_update: Additional keyword arguments that will be passed with preds and
            target when running update on the metric.
    """
+    assert preds.shape[0] == target.shape[0]
+    num_batches = preds.shape[0]
+
    if not metric_args:
        metric_args = {}

@@ -149,7 +152,7 @@
    pickled_metric = pickle.dumps(metric)
    metric = pickle.loads(pickled_metric)

-    for i in range(rank, NUM_BATCHES, worldsize):
+    for i in range(rank, num_batches, worldsize):
        batch_kwargs_update = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}

        batch_result = metric(preds[i], target[i], **batch_kwargs_update)
@@ -177,10 +180,10 @@ def _class_test(
    result = metric.compute()
    _assert_tensor(result)

-    total_preds = torch.cat([preds[i] for i in range(NUM_BATCHES)]).cpu()
-    total_target = torch.cat([target[i] for i in range(NUM_BATCHES)]).cpu()
+    total_preds = torch.cat([preds[i] for i in range(num_batches)]).cpu()
+    total_target = torch.cat([target[i] for i in range(num_batches)]).cpu()
    total_kwargs_update = {
-        k: torch.cat([v[i] for i in range(NUM_BATCHES)]).cpu() if isinstance(v, Tensor) else v
+        k: torch.cat([v[i] for i in range(num_batches)]).cpu() if isinstance(v, Tensor) else v
        for k, v in kwargs_update.items()
    }
    sk_result = sk_metric(total_preds, total_target, **total_kwargs_update)
@@ -213,6 +216,9 @@ def _functional_test(
        kwargs_update: Additional keyword arguments that will be passed with preds and
            target when running update on the metric.
    """
+    assert preds.shape[0] == target.shape[0]
+    num_batches = preds.shape[0]
+
    if not metric_args:
        metric_args = {}

@@ -223,7 +229,7 @@
    target = target.to(device)
    kwargs_update = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}

-    for i in range(NUM_BATCHES):
+    for i in range(num_batches):
        extra_kwargs = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}
        lightning_result = metric(preds[i], target[i], **extra_kwargs)
        extra_kwargs = {
@@ -238,7 +244,7 @@

def _assert_half_support(
    metric_module: Metric,
-    metric_functional: Callable,
+    metric_functional: Optional[Callable],
    preds: Tensor,
    target: Tensor,
    device: str = "cpu",
@@ -263,7 +269,8 @@
    }
    metric_module = metric_module.to(device)
    _assert_tensor(metric_module(y_hat, y, **kwargs_update))
-    _assert_tensor(metric_functional(y_hat, y, **kwargs_update))
+    if metric_functional is not None:
+        _assert_tensor(metric_functional(y_hat, y, **kwargs_update))


class MetricTester:
@@ -411,8 +418,8 @@ def run_precision_test_cpu(
        preds: Tensor,
        target: Tensor,
        metric_module: Metric,
-        metric_functional: Callable,
-        metric_args: dict = None,
+        metric_functional: Optional[Callable] = None,
+        metric_args: Optional[dict] = None,
        **kwargs_update,
    ):
        """Test if a metric can be used with half precision tensors on cpu
@@ -435,8 +442,8 @@ def run_precision_test_gpu(
        preds: Tensor,
        target: Tensor,
        metric_module: Metric,
-        metric_functional: Callable,
-        metric_args: dict = None,
+        metric_functional: Optional[Callable] = None,
+        metric_args: Optional[dict] = None,
        **kwargs_update,
    ):
        """Test if a metric can be used with half precision tensors on gpu
@@ -459,8 +466,8 @@ def run_differentiability_test(
        preds: Tensor,
        target: Tensor,
        metric_module: Metric,
-        metric_functional: Callable,
-        metric_args: dict = None,
+        metric_functional: Optional[Callable] = None,
+        metric_args: Optional[dict] = None,
    ):
        """Test if a metric is differentiable or not.

@@ -480,7 +487,7 @@ def run_differentiability_test(
        # Check if requires_grad matches is_differentiable attribute
        _assert_requires_grad(metric, out)

-        if metric.is_differentiable:
+        if metric.is_differentiable and metric_functional is not None:
            # check for numerical correctness
            assert torch.autograd.gradcheck(
                partial(metric_functional, **metric_args), (preds[0].double(), target[0])
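
The `num_batches` changes above are what allow the LPIPS test below to supply fewer, smaller batches than the suite-wide defaults; previously the testers always looped over the global `NUM_BATCHES`. A sketch of the pattern, with made-up stand-ins for the suite constants:

```python
import torch

# Hypothetical stand-ins for the suite globals in tests/helpers/testers.py.
NUM_BATCHES, BATCH_SIZE = 10, 32

# The LPIPS test builds reduced inputs (40% of the batches, BATCH_SIZE / 16
# samples per batch) because every update pushes images through a CNN backbone.
preds = torch.rand(int(NUM_BATCHES * 0.4), int(BATCH_SIZE / 16), 3, 100, 100)
target = torch.rand(int(NUM_BATCHES * 0.4), int(BATCH_SIZE / 16), 3, 100, 100)

# With this PR, the testers derive the loop bound from the data itself:
assert preds.shape[0] == target.shape[0]
num_batches = preds.shape[0]  # 4 here, instead of the global NUM_BATCHES
for i in range(num_batches):
    batch = (preds[i], target[i])  # one metric update per batch
```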
tests/image/test_lpips.py (102 additions, 0 deletions)

@@ -0,0 +1,102 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import pytest
import torch
from lpips import LPIPS as reference_LPIPS
from torch import Tensor

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.image.lpip_similarity import LPIPS
from torchmetrics.utilities.imports import _LPIPS_AVAILABLE

seed_all(42)

Input = namedtuple("Input", ["img1", "img2"])

_inputs = Input(
    img1=torch.rand(int(NUM_BATCHES * 0.4), int(BATCH_SIZE / 16), 3, 100, 100),
    img2=torch.rand(int(NUM_BATCHES * 0.4), int(BATCH_SIZE / 16), 3, 100, 100),
)


def _compare_fn(img1: Tensor, img2: Tensor, net_type: str, reduction: str = "mean") -> Tensor:
    """Reference comparison function, wrapping the original ``lpips`` implementation."""
    ref = reference_LPIPS(net=net_type)
    res = ref(img1, img2).detach().cpu().numpy()
    if reduction == "mean":
        return res.mean()
    return res.sum()


@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
@pytest.mark.parametrize("net_type", ["vgg", "alex", "squeeze"])
class TestLPIPS(MetricTester):
    @pytest.mark.parametrize("ddp", [True, False])
    def test_lpips(self, net_type, ddp):
        """Test the modular implementation for correctness."""
        self.run_class_metric_test(
            ddp=ddp,
            preds=_inputs.img1,
            target=_inputs.img2,
            metric_class=LPIPS,
            sk_metric=partial(_compare_fn, net_type=net_type),
            dist_sync_on_step=False,
            metric_args={"net_type": net_type},
        )

    def test_lpips_differentiability(self, net_type):
        """Test the differentiability of the LPIPS metric."""
        self.run_differentiability_test(preds=_inputs.img1, target=_inputs.img2, metric_module=LPIPS)

    # LPIPS half + cpu does not work due to missing support in torch.min
    @pytest.mark.xfail(reason="LPIPS metric does not support cpu + half precision")
    def test_lpips_half_cpu(self, net_type):
        """Test for half + cpu support."""
        self.run_precision_test_cpu(_inputs.img1, _inputs.img2, LPIPS)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
    def test_lpips_half_gpu(self, net_type):
        """Test for half + gpu support."""
        self.run_precision_test_gpu(_inputs.img1, _inputs.img2, LPIPS)


@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
def test_error_on_wrong_init():
    """Test that the class raises the expected errors."""
    with pytest.raises(ValueError, match="Argument `net_type` must be one .*"):
        LPIPS(net_type="resnet")

    with pytest.raises(ValueError, match="Argument `reduction` must be one .*"):
        LPIPS(reduction=None)


@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
@pytest.mark.parametrize(
    "inp1, inp2",
    [
        (torch.rand(1, 1, 28, 28), torch.rand(1, 3, 28, 28)),  # wrong number of channels
        (torch.rand(1, 3, 28, 28), torch.rand(1, 1, 28, 28)),  # wrong number of channels
        (torch.randn(1, 3, 28, 28), torch.rand(1, 3, 28, 28)),  # non-normalized input
        (torch.rand(1, 3, 28, 28), torch.randn(1, 3, 28, 28)),  # non-normalized input
    ],
)
def test_error_on_wrong_update(inp1, inp2):
    """Test that an error is raised on wrong input to the update method."""
    metric = LPIPS()
    with pytest.raises(ValueError, match="Expected both input arguments to be normalized tensors .*"):
        metric(inp1, inp2)
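
Note what the parametrization above implies about input validation: `torch.rand` inputs (values in [0, 1]) pass, while `torch.randn` inputs (unbounded values) raise, so the update method checks value range as well as channel count. A sketch of that behavior at the call site, assuming this PR's class and using the same image size as the correctness tests:

```python
import torch

from torchmetrics.image.lpip_similarity import LPIPS

metric = LPIPS()  # default arguments, as in test_error_on_wrong_update above

# In-range images are accepted ...
metric(torch.rand(1, 3, 100, 100), torch.rand(1, 3, 100, 100))

# ... while unbounded values trip the range check:
try:
    metric(torch.randn(1, 3, 100, 100), torch.rand(1, 3, 100, 100))
except ValueError as err:
    print(err)  # "Expected both input arguments to be normalized tensors ..."
```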
torchmetrics/__init__.py (2 additions, 1 deletion)

@@ -40,7 +40,7 @@
    StatScores,
)
from torchmetrics.collections import MetricCollection  # noqa: E402
-from torchmetrics.image import FID, IS, KID, PSNR, SSIM  # noqa: E402
+from torchmetrics.image import FID, IS, KID, LPIPS, PSNR, SSIM  # noqa: E402
from torchmetrics.metric import Metric  # noqa: E402
from torchmetrics.regression import (  # noqa: E402
    CosineSimilarity,
@@ -92,6 +92,7 @@
    "IS",
    "KID",
    "KLDivergence",
+    "LPIPS",
    "MatthewsCorrcoef",
    "MeanAbsoluteError",
    "MeanAbsolutePercentageError",
torchmetrics/image/__init__.py (1 addition, 0 deletions)

@@ -14,5 +14,6 @@
from torchmetrics.image.fid import FID  # noqa: F401
from torchmetrics.image.inception import IS  # noqa: F401
from torchmetrics.image.kid import KID  # noqa: F401
+from torchmetrics.image.lpip_similarity import LPIPS  # noqa: F401
from torchmetrics.image.psnr import PSNR  # noqa: F401
from torchmetrics.image.ssim import SSIM  # noqa: F401