Add dim to pytorch_lightning.metrics.PSNR (#5957)
* Add dim to PSNR

* Update CHANGELOG.md

* Update CHANGELOG.md

Co-authored-by: Jirka Borovec <[email protected]>

* Add reduction tests

* Recover warnings on reduction and add tests

* Add copyright texts

* Refactor PSNR

* Change warnings

* Update pytorch_lightning/metrics/functional/psnr.py

Change functional.psnr dim doc

Co-authored-by: Nicki Skafte <[email protected]>

* Change PSNR dim docs

* Apply suggestions from code review

* tests

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
4 people authored Feb 17, 2021
1 parent 7aae589 commit 6a9cec4
Showing 4 changed files with 185 additions and 54 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -87,6 +87,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `LightningModule.configure_callbacks` to enable the definition of model-specific callbacks ([#5621](https://github.com/PyTorchLightning/pytorch-lightning/pull/5621))


- Added `dim` to `PSNR` metric for mean-squared-error reduction ([#5957](https://github.com/PyTorchLightning/pytorch-lightning/pull/5957))


- Added proximal policy optimization template to pl_examples ([#5394](https://github.com/PyTorchLightning/pytorch-lightning/pull/5394))


69 changes: 56 additions & 13 deletions pytorch_lightning/metrics/functional/psnr.py
@@ -1,27 +1,56 @@
from typing import Optional, Tuple
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union

import torch

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning import utilities
from pytorch_lightning.metrics import utils


def _psnr_compute(
sum_squared_error: torch.Tensor,
n_obs: int,
data_range: float,
n_obs: torch.Tensor,
data_range: torch.Tensor,
base: float = 10.0,
reduction: str = 'elementwise_mean',
) -> torch.Tensor:
if reduction != 'elementwise_mean':
rank_zero_warn(f'The `reduction={reduction}` parameter is unused and will not have any effect.')
psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs)
psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
return psnr
return utils.reduce(psnr, reduction=reduction)


def _psnr_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
sum_squared_error = torch.sum(torch.pow(preds - target, 2))
n_obs = target.numel()
def _psnr_update(preds: torch.Tensor,
target: torch.Tensor,
dim: Optional[Union[int, Tuple[int, ...]]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
if dim is None:
sum_squared_error = torch.sum(torch.pow(preds - target, 2))
n_obs = torch.tensor(target.numel(), device=target.device)
return sum_squared_error, n_obs

sum_squared_error = torch.sum(torch.pow(preds - target, 2), dim=dim)

if isinstance(dim, int):
dim_list = [dim]
else:
dim_list = list(dim)
if not dim_list:
n_obs = torch.tensor(target.numel(), device=target.device)
else:
n_obs = torch.tensor(target.size(), device=target.device)[dim_list].prod()
n_obs = n_obs.expand_as(sum_squared_error)

return sum_squared_error, n_obs


@@ -31,21 +60,27 @@ def psnr(
data_range: Optional[float] = None,
base: float = 10.0,
reduction: str = 'elementwise_mean',
dim: Optional[Union[int, Tuple[int, ...]]] = None,
) -> torch.Tensor:
"""
Computes the peak signal-to-noise ratio
Args:
preds: estimated signal
target: ground truth signal
data_range: the range of the data. If None, it is determined from the data (max - min)
data_range:
the range of the data. If None, it is determined from the data (max - min). ``data_range`` must be given
when ``dim`` is not None.
base: a base of a logarithm to use (default: 10)
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
dim:
Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is
None, meaning scores will be reduced across all dimensions.
Return:
Tensor with PSNR score
@@ -57,9 +92,17 @@ def psnr(
tensor(2.5527)
"""
if dim is None and reduction != 'elementwise_mean':
utilities.rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.')

if data_range is None:
if dim is not None:
# Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to calculate
# `data_range` in the future.
raise ValueError("The `data_range` must be given when `dim` is not None.")

data_range = target.max() - target.min()
else:
data_range = torch.tensor(float(data_range))
sum_squared_error, n_obs = _psnr_update(preds, target)
return _psnr_compute(sum_squared_error, n_obs, data_range, base, reduction)
sum_squared_error, n_obs = _psnr_update(preds, target, dim=dim)
return _psnr_compute(sum_squared_error, n_obs, data_range, base=base, reduction=reduction)
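
For orientation (this example is not part of the commit; it is a minimal sketch assuming a PyTorch Lightning build that includes these changes), the functional `psnr` with the default `base=10` still evaluates `10 * log10(data_range ** 2 / MSE)`, but the new `dim` argument controls which dimensions the mean-squared error is taken over, and `reduction` then combines the resulting per-entry scores:

```python
import torch
from pytorch_lightning.metrics.functional.psnr import psnr

preds = torch.rand(4, 3, 16, 16)   # hypothetical batch of 4 three-channel images
target = torch.rand(4, 3, 16, 16)

# Pre-existing behaviour: a single scalar PSNR over every element.
overall = psnr(preds, target, data_range=1.0)

# New behaviour: reduce the squared error over the channel/height/width dims
# only, giving one PSNR per sample; `data_range` is mandatory once `dim` is set.
per_sample = psnr(preds, target, data_range=1.0, dim=(1, 2, 3), reduction='none')

print(overall.shape, per_sample.shape)  # torch.Size([]) torch.Size([4])
```

With the default `reduction='elementwise_mean'`, the per-sample scores would instead be averaged into a single scalar.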
58 changes: 45 additions & 13 deletions pytorch_lightning/metrics/regression/psnr.py
@@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
from typing import Any, Optional, Sequence, Tuple, Union

import torch

from pytorch_lightning import utilities
from pytorch_lightning.metrics.functional.psnr import _psnr_compute, _psnr_update
from pytorch_lightning.metrics.metric import Metric

@@ -29,14 +30,19 @@ class PSNR(Metric):
<https://en.wikipedia.org/wiki/Mean_squared_error>`_ function.
Args:
data_range: the range of the data. If None, it is determined from the data (max - min)
data_range:
the range of the data. If None, it is determined from the data (max - min).
The ``data_range`` must be given when ``dim`` is not None.
base: a base of a logarithm to use (default: 10)
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
dim:
Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is
None, meaning scores will be reduced across all dimensions and all batches.
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False. default: True
dist_sync_on_step:
@@ -61,6 +67,7 @@ def __init__(
data_range: Optional[float] = None,
base: float = 10.0,
reduction: str = 'elementwise_mean',
dim: Optional[Union[int, Tuple[int, ...]]] = None,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
@@ -71,16 +78,30 @@ def __init__(
process_group=process_group,
)

self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
if dim is None and reduction != 'elementwise_mean':
utilities.rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.')

if dim is None:
self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
else:
self.add_state("sum_squared_error", default=[])
self.add_state("total", default=[])

if data_range is None:
if dim is not None:
# Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to
# calculate `data_range` in the future.
raise ValueError("The `data_range` must be given when `dim` is not None.")

self.data_range = None
self.add_state("min_target", default=torch.tensor(0.0), dist_reduce_fx=torch.min)
self.add_state("max_target", default=torch.tensor(0.0), dist_reduce_fx=torch.max)
else:
self.register_buffer("data_range", torch.tensor(float(data_range)))
self.base = base
self.reduction = reduction
self.dim = tuple(dim) if isinstance(dim, Sequence) else dim

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
@@ -90,14 +111,18 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
preds: Predictions from model
target: Ground truth values
"""
if self.data_range is None:
# keep track of min and max target values
self.min_target = min(target.min(), self.min_target)
self.max_target = max(target.max(), self.max_target)

sum_squared_error, n_obs = _psnr_update(preds, target)
self.sum_squared_error += sum_squared_error
self.total += n_obs
sum_squared_error, n_obs = _psnr_update(preds, target, dim=self.dim)
if self.dim is None:
if self.data_range is None:
# keep track of min and max target values
self.min_target = min(target.min(), self.min_target)
self.max_target = max(target.max(), self.max_target)

self.sum_squared_error += sum_squared_error
self.total += n_obs
else:
self.sum_squared_error.append(sum_squared_error)
self.total.append(n_obs)

def compute(self):
"""
@@ -107,4 +132,11 @@ def compute(self):
data_range = self.data_range
else:
data_range = self.max_target - self.min_target
return _psnr_compute(self.sum_squared_error, self.total, data_range, self.base, self.reduction)

if self.dim is None:
sum_squared_error = self.sum_squared_error
total = self.total
else:
sum_squared_error = torch.cat([values.flatten() for values in self.sum_squared_error])
total = torch.cat([values.flatten() for values in self.total])
return _psnr_compute(sum_squared_error, total, data_range, base=self.base, reduction=self.reduction)
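
The same option on the module interface, as a minimal sketch (again not part of the diff; it assumes the `pytorch_lightning.metrics.PSNR` export exercised by the tests below). When `dim` is given, the states become lists of per-step tensors that `compute` flattens and concatenates before reducing:

```python
import torch
from pytorch_lightning.metrics import PSNR

# Per-sample PSNR over the image dimensions; `data_range` is required with `dim`.
metric = PSNR(data_range=1.0, dim=(1, 2, 3))

for _ in range(2):  # two hypothetical validation batches of 4 samples each
    preds = torch.rand(4, 3, 16, 16)
    target = torch.rand(4, 3, 16, 16)
    metric.update(preds, target)

# `compute` concatenates the accumulated squared-error and observation-count
# states (8 samples in total), turns them into per-sample PSNR values, and
# averages them under the default 'elementwise_mean' reduction.
print(metric.compute())
```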
109 changes: 81 additions & 28 deletions tests/metrics/regression/test_psnr.py
@@ -1,3 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import namedtuple
from functools import partial

@@ -14,67 +28,106 @@

Input = namedtuple('Input', ["preds", "target"])

_input_size = (NUM_BATCHES, BATCH_SIZE, 32, 32)
_inputs = [
Input(
preds=torch.randint(n_cls_pred, (NUM_BATCHES, BATCH_SIZE), dtype=torch.float),
target=torch.randint(n_cls_target, (NUM_BATCHES, BATCH_SIZE), dtype=torch.float),
preds=torch.randint(n_cls_pred, _input_size, dtype=torch.float),
target=torch.randint(n_cls_target, _input_size, dtype=torch.float),
) for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]
]


def _sk_metric(preds, target, data_range):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range)
def _to_sk_peak_signal_noise_ratio_inputs(value, dim):
value = value.numpy()
batches = value[None] if value.ndim == len(_input_size) - 1 else value

if dim is None:
return [batches]

num_dims = np.size(dim)
if not num_dims:
return batches

inputs = []
for batch in batches:
batch = np.moveaxis(batch, dim, np.arange(-num_dims, 0))
psnr_input_shape = batch.shape[-num_dims:]
inputs.extend(batch.reshape(-1, *psnr_input_shape))
return inputs


def _sk_psnr(preds, target, data_range, reduction, dim):
sk_preds_lists = _to_sk_peak_signal_noise_ratio_inputs(preds, dim=dim)
sk_target_lists = _to_sk_peak_signal_noise_ratio_inputs(target, dim=dim)
np_reduce_map = {"elementwise_mean": np.mean, "none": np.array, "sum": np.sum}
return np_reduce_map[reduction]([
peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range)
for sk_target, sk_preds in zip(sk_target_lists, sk_preds_lists)
])


def _base_e_sk_metric(preds, target, data_range):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range) * np.log(10)
def _base_e_sk_psnr(preds, target, data_range, reduction, dim):
return _sk_psnr(preds, target, data_range, reduction, dim) * np.log(10)


@pytest.mark.parametrize(
"preds, target, data_range",
"preds, target, data_range, reduction, dim",
[
(_inputs[0].preds, _inputs[0].target, 10),
(_inputs[1].preds, _inputs[1].target, 10),
(_inputs[2].preds, _inputs[2].target, 5),
(_inputs[0].preds, _inputs[0].target, 10, "elementwise_mean", None),
(_inputs[1].preds, _inputs[1].target, 10, "elementwise_mean", None),
(_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", None),
(_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", 1),
(_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", (1, 2)),
(_inputs[2].preds, _inputs[2].target, 5, "sum", (1, 2)),
],
)
@pytest.mark.parametrize(
"base, sk_metric",
[
(10.0, _sk_metric),
(2.718281828459045, _base_e_sk_metric),
(10.0, _sk_psnr),
(2.718281828459045, _base_e_sk_psnr),
],
)
class TestPSNR(MetricTester):

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_psnr(self, preds, target, data_range, base, sk_metric, ddp, dist_sync_on_step):
def test_psnr(self, preds, target, data_range, base, reduction, dim, sk_metric, ddp, dist_sync_on_step):
_args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim}
self.run_class_metric_test(
ddp,
preds,
target,
PSNR,
partial(sk_metric, data_range=data_range),
metric_args={
"data_range": data_range,
"base": base
},
partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim),
metric_args=_args,
dist_sync_on_step=dist_sync_on_step,
)

def test_psnr_functional(self, preds, target, sk_metric, data_range, base):
def test_psnr_functional(self, preds, target, sk_metric, data_range, base, reduction, dim):
_args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim}
self.run_functional_metric_test(
preds,
target,
psnr,
partial(sk_metric, data_range=data_range),
metric_args={
"data_range": data_range,
"base": base
},
partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim),
metric_args=_args,
)


@pytest.mark.parametrize("reduction", ["none", "sum"])
def test_reduction_for_dim_none(reduction):
match = f"The `reduction={reduction}` will not have any effect when `dim` is None."
with pytest.warns(UserWarning, match=match):
PSNR(reduction=reduction, dim=None)

with pytest.warns(UserWarning, match=match):
psnr(_inputs[0].preds, _inputs[0].target, reduction=reduction, dim=None)


def test_missing_data_range():
with pytest.raises(ValueError):
PSNR(data_range=None, dim=0)

with pytest.raises(ValueError):
psnr(_inputs[0].preds, _inputs[0].target, data_range=None, dim=0)
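
The two new standalone tests above pin down the added guard rails; roughly, they correspond to the following behaviour (illustrative sketch only, same assumed imports as above):

```python
import torch
from pytorch_lightning.metrics import PSNR
from pytorch_lightning.metrics.functional.psnr import psnr

preds, target = torch.rand(2, 8, 8), torch.rand(2, 8, 8)

# Emits a UserWarning: with `dim=None` everything is averaged into one score,
# so `reduction` has no effect.
psnr(preds, target, reduction='sum')

# Raises ValueError: the data range cannot be inferred per entry, so it must
# be supplied explicitly whenever `dim` is given.
try:
    PSNR(data_range=None, dim=0)
except ValueError as err:
    print(err)
```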
