[Metrics] Unification of regression (#4166)
* moved to utility

* add files

* unify

* add desc

* update

* end of line

* Apply suggestions from code review

Co-authored-by: Justus Schock <[email protected]>

* Apply suggestions from code review

Co-authored-by: Justus Schock <[email protected]>

* add back functional test in new interface

* pep8

* doctest fix

* test name fix

* unify psnr + add class psnr, TODO: psnr test refactor ala mean squared error

* unify psnr

* rm unused code

* pep8

* docs

* unify ssim

* lower tolerance for ssim

* fix import

* pep8

* docs

* flake8

* test smaller images

* trying to fix test

* no ddp test for ssim

* pep8

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: Teddy Koker <[email protected]>
4 people authored Oct 21, 2020
1 parent 546476c commit a937394
Showing 28 changed files with 1,039 additions and 623 deletions.
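
For orientation, a minimal sketch of the class-based metric pattern this unification targets, using ``MeanAbsoluteError`` from the updated ``pytorch_lightning.metrics`` exports. The newly added ``PSNR`` and ``SSIM`` classes follow the same ``update()``/``compute()`` pattern; their constructor arguments are not shown in this diff, so they are left out here.

import torch
import pytorch_lightning as pl

# Class-based regression metric; constructor defaults assumed (not part of this diff).
mae = pl.metrics.MeanAbsoluteError()

# update() accumulates internal state batch by batch ...
for _ in range(3):
    preds = torch.rand(16)
    target = torch.rand(16)
    mae.update(preds, target)

# ... and compute() reduces the accumulated state (synced across processes under DDP).
print(mae.compute())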
60 changes: 37 additions & 23 deletions docs/source/metrics.rst
@@ -31,14 +31,14 @@ The example below shows how to use a metric in your ``LightningModule``:
def __init__(self):
...
self.accuracy = pl.metrics.Accuracy()
def training_step(self, batch, batch_idx):
logits = self(x)
...
# log step metric
self.log('train_acc_step', self.accuracy(logits, y))
...
def training_epoch_end(self, outs):
# log epoch metric
self.log('train_acc_epoch', self.accuracy.compute())
@@ -57,15 +57,15 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric v
This, however, is only true for metrics that inherit the base class ``Metric``,
and thus the functional metric API provides no built-in support for distributed synchronization
or reduction functions.


.. code-block:: python
def __init__(self):
...
self.train_acc = pl.metrics.Accuracy()
self.valid_acc = pl.metrics.Accuracy()
def training_step(self, batch, batch_idx):
logits = self(x)
...
@@ -91,17 +91,17 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us
for epoch in range(epochs):
for x, y in train_data:
y_hat = model(x)
# training step accuracy
batch_acc = train_accuracy(y_hat, y)
for x, y in valid_data:
y_hat = model(x)
valid_accuracy(y_hat, y)
# total accuracy over all training batches
total_train_accuracy = train_accuracy.compute()
# total accuracy over all validation batches
total_valid_accuracy = valid_accuracy.compute()
@@ -212,6 +212,20 @@ ExplainedVariance
.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance
:noindex:


PSNR
~~~~

.. autoclass:: pytorch_lightning.metrics.regression.PSNR
:noindex:


SSIM
~~~~

.. autoclass:: pytorch_lightning.metrics.regression.SSIM
:noindex:

******************
Functional Metrics
******************
@@ -360,45 +374,45 @@ to_onehot [func]
Regression
----------

mae [func]
~~~~~~~~~~
explained_variance [func]
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.regression.mae
.. autofunction:: pytorch_lightning.metrics.functional.explained_variance
:noindex:


mse [func]
~~~~~~~~~~
mean_absolute_error [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.regression.mse
.. autofunction:: pytorch_lightning.metrics.functional.mean_absolute_error
:noindex:


psnr [func]
~~~~~~~~~~~
mean_squared_error [func]
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.regression.psnr
.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_error
:noindex:


rmse [func]
psnr [func]
~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.regression.rmse
.. autofunction:: pytorch_lightning.metrics.functional.psnr
:noindex:


rmsle [func]
~~~~~~~~~~~~
mean_squared_log_error [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.regression.rmsle
.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_log_error
:noindex:


ssim [func]
~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.regression.mae
.. autofunction:: pytorch_lightning.metrics.functional.ssim
:noindex:


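For the renamed functional API documented above, a small sketch of how the new entry points are called; the argument defaults for ``psnr`` and ``ssim`` are assumptions, since their signatures are not part of this diff.

import torch
from pytorch_lightning.metrics.functional import (
    mean_absolute_error,
    mean_squared_error,
    mean_squared_log_error,
    psnr,
    ssim,
)

x = torch.tensor([0., 1, 2, 3])
y = torch.tensor([0., 1, 2, 2])
print(mean_absolute_error(x, y))    # tensor(0.2500)
print(mean_squared_error(x, y))     # tensor(0.2500)
print(mean_squared_log_error(x, y))

# ssim expects image-shaped (N, C, H, W) tensors; psnr accepts any matching shapes.
# Keyword defaults (kernel size, data range, reduction) are assumed here.
imgs_pred = torch.rand(4, 3, 32, 32)
imgs_target = torch.rand(4, 3, 32, 32)
print(psnr(imgs_pred, imgs_target))
print(ssim(imgs_pred, imgs_target))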
2 changes: 2 additions & 0 deletions pytorch_lightning/metrics/__init__.py
@@ -25,4 +25,6 @@
MeanAbsoluteError,
MeanSquaredLogError,
ExplainedVariance,
PSNR,
SSIM,
)
15 changes: 7 additions & 8 deletions pytorch_lightning/metrics/functional/__init__.py
@@ -34,14 +34,13 @@
iou,
)
from pytorch_lightning.metrics.functional.nlp import bleu_score
from pytorch_lightning.metrics.functional.regression import (
mae,
mse,
psnr,
rmse,
rmsle,
ssim
)
from pytorch_lightning.metrics.functional.self_supervised import (
embedding_similarity
)
# TODO: unify metrics between class and functional, add below
from pytorch_lightning.metrics.functional.explained_variance import explained_variance
from pytorch_lightning.metrics.functional.mean_absolute_error import mean_absolute_error
from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error
from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error
from pytorch_lightning.metrics.functional.psnr import psnr
from pytorch_lightning.metrics.functional.ssim import ssim
85 changes: 85 additions & 0 deletions pytorch_lightning/metrics/functional/explained_variance.py
@@ -0,0 +1,85 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, Tuple, Sequence

import torch
from pytorch_lightning.metrics.utils import _check_same_shape


def _explained_variance_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
_check_same_shape(preds, target)
return preds, target


def _explained_variance_compute(preds: torch.Tensor,
target: torch.Tensor,
multioutput: str = 'uniform_average',
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
diff_avg = torch.mean(target - preds, dim=0)
numerator = torch.mean((target - preds - diff_avg) ** 2, dim=0)

target_avg = torch.mean(target, dim=0)
denominator = torch.mean((target - target_avg) ** 2, dim=0)

# Take care of division by zero
nonzero_numerator = numerator != 0
nonzero_denominator = denominator != 0
valid_score = nonzero_numerator & nonzero_denominator
output_scores = torch.ones_like(diff_avg)
output_scores[valid_score] = 1.0 - (numerator[valid_score] / denominator[valid_score])
output_scores[nonzero_numerator & ~nonzero_denominator] = 0.

# Decide what to do in multioutput case
# Todo: allow user to pass in tensor with weights
if multioutput == 'raw_values':
return output_scores
if multioutput == 'uniform_average':
return torch.mean(output_scores)
if multioutput == 'variance_weighted':
denom_sum = torch.sum(denominator)
return torch.sum(denominator / denom_sum * output_scores)


def explained_variance(preds: torch.Tensor,
target: torch.Tensor,
multioutput: str = 'uniform_average',
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
"""
Computes explained variance.
Args:
preds: estimated labels
target: ground truth labels
multioutput: Defines aggregation in the case of multiple output scores. Can be one
of the following strings (default is `'uniform_average'`):
* `'raw_values'` returns full set of scores
* `'uniform_average'` scores are uniformly averaged
* `'variance_weighted'` scores are weighted by their individual variances
Example:
>>> from pytorch_lightning.metrics.functional import explained_variance
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> explained_variance(preds, target)
tensor(0.9572)
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> explained_variance(preds, target, multioutput='raw_values')
tensor([0.9677, 1.0000])
"""
preds, target = _explained_variance_update(preds, target)
return _explained_variance_compute(preds, target, multioutput)
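As a quick illustration of the ``'variance_weighted'`` mode handled in ``_explained_variance_compute`` above (a sketch only; no expected output is asserted here):

import torch
from pytorch_lightning.metrics.functional import explained_variance

target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])

# Each output's score is weighted by the variance of that output in `target`,
# then summed, yielding a single scalar tensor.
print(explained_variance(preds, target, multioutput='variance_weighted'))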
51 changes: 51 additions & 0 deletions pytorch_lightning/metrics/functional/mean_absolute_error.py
@@ -0,0 +1,51 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch
from pytorch_lightning.metrics.utils import _check_same_shape


def _mean_absolute_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
_check_same_shape(preds, target)
sum_abs_error = torch.sum(torch.abs(preds - target))
n_obs = target.numel()
return sum_abs_error, n_obs


def _mean_absolute_error_compute(sum_abs_error: torch.Tensor, n_obs: int) -> torch.Tensor:
return sum_abs_error / n_obs


def mean_absolute_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Computes mean absolute error
Args:
preds: estimated labels
target: ground truth labels
Return:
Tensor with MAE
Example:
>>> x = torch.tensor([0., 1, 2, 3])
>>> y = torch.tensor([0., 1, 2, 2])
>>> mean_absolute_error(x, y)
tensor(0.2500)
"""
sum_abs_error, n_obs = _mean_absolute_error_update(preds, target)
return _mean_absolute_error_compute(sum_abs_error, n_obs)
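The ``_update``/``_compute`` split above exists so that a running sum and count can be accumulated batch by batch and reduced once at the end. A minimal sketch of that accumulation using the private helpers defined in this file (shown purely for illustration, not as a public API):

import torch
from pytorch_lightning.metrics.functional.mean_absolute_error import (
    _mean_absolute_error_update,
    _mean_absolute_error_compute,
)

batches = [(torch.rand(16), torch.rand(16)) for _ in range(4)]

total_abs_error = torch.tensor(0.)
total_obs = 0
for preds, target in batches:
    # Per-batch partial statistics: sum of absolute errors and number of observations.
    sum_abs_error, n_obs = _mean_absolute_error_update(preds, target)
    total_abs_error += sum_abs_error
    total_obs += n_obs

# A single reduction over everything seen so far, mirroring what a Metric.compute() would do.
print(_mean_absolute_error_compute(total_abs_error, total_obs))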
51 changes: 51 additions & 0 deletions pytorch_lightning/metrics/functional/mean_squared_error.py
@@ -0,0 +1,51 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch
from pytorch_lightning.metrics.utils import _check_same_shape


def _mean_squared_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
_check_same_shape(preds, target)
sum_squared_error = torch.sum(torch.pow(preds - target, 2))
n_obs = target.numel()
return sum_squared_error, n_obs


def _mean_squared_error_compute(sum_squared_error: torch.Tensor, n_obs: int) -> torch.Tensor:
return sum_squared_error / n_obs


def mean_squared_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Computes mean squared error
Args:
preds: estimated labels
target: ground truth labels
Return:
Tensor with MSE
Example:
>>> x = torch.tensor([0., 1, 2, 3])
>>> y = torch.tensor([0., 1, 2, 2])
>>> mean_squared_error(x, y)
tensor(0.2500)
"""
sum_squared_error, n_obs = _mean_squared_error_update(preds, target)
return _mean_squared_error_compute(sum_squared_error, n_obs)
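Note that the separately named ``rmse`` and ``rmsle`` functionals are removed from ``pytorch_lightning.metrics.functional`` in this diff. A root-mean-squared value can still be obtained by taking the square root of the corresponding mean metric (a sketch; not claiming numerical equivalence with the removed helpers):

import torch
from pytorch_lightning.metrics.functional import mean_squared_error, mean_squared_log_error

x = torch.tensor([0., 1, 2, 3])
y = torch.tensor([0., 1, 2, 2])

# Root variants via a plain square root of the mean metrics.
rmse = torch.sqrt(mean_squared_error(x, y))
rmsle = torch.sqrt(mean_squared_log_error(x, y))
print(rmse, rmsle)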