-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Metrics] Unification of regression (#4166)
* moved to utility * add files * unify * add desc * update * end of line * Apply suggestions from code review Co-authored-by: Justus Schock <[email protected]> * Apply suggestions from code review Co-authored-by: Justus Schock <[email protected]> * add back functional test in new interface * pep8 * doctest fix * test name fix * unify psnr + add class psnr, TODO: psnr test refactor ala mean squared error * unify psnr * rm unused code * pep8 * docs * unify ssim * lower tolerance for ssim * fix import * pep8 * docs * flake8 * test smaller images * trying to fix test * no ddp test for ssim * pep8 Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Justus Schock <[email protected]> Co-authored-by: Teddy Koker <[email protected]>
- Loading branch information
1 parent
546476c
commit a937394
Showing
28 changed files
with
1,039 additions
and
623 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,4 +25,6 @@ | |
MeanAbsoluteError, | ||
MeanSquaredLogError, | ||
ExplainedVariance, | ||
PSNR, | ||
SSIM, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
85 changes: 85 additions & 0 deletions
85
pytorch_lightning/metrics/functional/explained_variance.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Union, Tuple, Sequence | ||
|
||
import torch | ||
from pytorch_lightning.metrics.utils import _check_same_shape | ||
|
||
|
||
def _explained_variance_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | ||
_check_same_shape(preds, target) | ||
return preds, target | ||
|
||
|
||
def _explained_variance_compute(preds: torch.Tensor, | ||
target: torch.Tensor, | ||
multioutput: str = 'uniform_average', | ||
) -> Union[torch.Tensor, Sequence[torch.Tensor]]: | ||
diff_avg = torch.mean(target - preds, dim=0) | ||
numerator = torch.mean((target - preds - diff_avg) ** 2, dim=0) | ||
|
||
target_avg = torch.mean(target, dim=0) | ||
denominator = torch.mean((target - target_avg) ** 2, dim=0) | ||
|
||
# Take care of division by zero | ||
nonzero_numerator = numerator != 0 | ||
nonzero_denominator = denominator != 0 | ||
valid_score = nonzero_numerator & nonzero_denominator | ||
output_scores = torch.ones_like(diff_avg) | ||
output_scores[valid_score] = 1.0 - (numerator[valid_score] / denominator[valid_score]) | ||
output_scores[nonzero_numerator & ~nonzero_denominator] = 0. | ||
|
||
# Decide what to do in multioutput case | ||
# Todo: allow user to pass in tensor with weights | ||
if multioutput == 'raw_values': | ||
return output_scores | ||
if multioutput == 'uniform_average': | ||
return torch.mean(output_scores) | ||
if multioutput == 'variance_weighted': | ||
denom_sum = torch.sum(denominator) | ||
return torch.sum(denominator / denom_sum * output_scores) | ||
|
||
|
||
def explained_variance(preds: torch.Tensor, | ||
target: torch.Tensor, | ||
multioutput: str = 'uniform_average', | ||
) -> Union[torch.Tensor, Sequence[torch.Tensor]]: | ||
""" | ||
Computes explained variance. | ||
Args: | ||
pred: estimated labels | ||
target: ground truth labels | ||
multioutput: Defines aggregation in the case of multiple output scores. Can be one | ||
of the following strings (default is `'uniform_average'`.): | ||
* `'raw_values'` returns full set of scores | ||
* `'uniform_average'` scores are uniformly averaged | ||
* `'variance_weighted'` scores are weighted by their individual variances | ||
Example: | ||
>>> from pytorch_lightning.metrics.functional import explained_variance | ||
>>> target = torch.tensor([3, -0.5, 2, 7]) | ||
>>> preds = torch.tensor([2.5, 0.0, 2, 8]) | ||
>>> explained_variance(preds, target) | ||
tensor(0.9572) | ||
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) | ||
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) | ||
>>> explained_variance(preds, target, multioutput='raw_values') | ||
tensor([0.9677, 1.0000]) | ||
""" | ||
preds, target = _explained_variance_update(preds, target) | ||
return _explained_variance_compute(preds, target, multioutput) |
51 changes: 51 additions & 0 deletions
51
pytorch_lightning/metrics/functional/mean_absolute_error.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Tuple | ||
|
||
import torch | ||
from pytorch_lightning.metrics.utils import _check_same_shape | ||
|
||
|
||
def _mean_absolute_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: | ||
_check_same_shape(preds, target) | ||
sum_abs_error = torch.sum(torch.abs(preds - target)) | ||
n_obs = target.numel() | ||
return sum_abs_error, n_obs | ||
|
||
|
||
def _mean_absolute_error_compute(sum_abs_error: torch.Tensor, n_obs: int) -> torch.Tensor: | ||
return sum_abs_error / n_obs | ||
|
||
|
||
def mean_absolute_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Computes mean absolute error | ||
Args: | ||
pred: estimated labels | ||
target: ground truth labels | ||
Return: | ||
Tensor with MAE | ||
Example: | ||
>>> x = torch.tensor([0., 1, 2, 3]) | ||
>>> y = torch.tensor([0., 1, 2, 2]) | ||
>>> mean_absolute_error(x, y) | ||
tensor(0.2500) | ||
""" | ||
sum_abs_error, n_obs = _mean_absolute_error_update(preds, target) | ||
return _mean_absolute_error_compute(sum_abs_error, n_obs) |
51 changes: 51 additions & 0 deletions
51
pytorch_lightning/metrics/functional/mean_squared_error.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Tuple | ||
|
||
import torch | ||
from pytorch_lightning.metrics.utils import _check_same_shape | ||
|
||
|
||
def _mean_squared_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: | ||
_check_same_shape(preds, target) | ||
sum_squared_error = torch.sum(torch.pow(preds - target, 2)) | ||
n_obs = target.numel() | ||
return sum_squared_error, n_obs | ||
|
||
|
||
def _mean_squared_error_compute(sum_squared_error: torch.Tensor, n_obs: int) -> torch.Tensor: | ||
return sum_squared_error / n_obs | ||
|
||
|
||
def mean_squared_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Computes mean squared error | ||
Args: | ||
pred: estimated labels | ||
target: ground truth labels | ||
Return: | ||
Tensor with MSE | ||
Example: | ||
>>> x = torch.tensor([0., 1, 2, 3]) | ||
>>> y = torch.tensor([0., 1, 2, 2]) | ||
>>> mean_squared_error(x, y) | ||
tensor(0.2500) | ||
""" | ||
sum_squared_error, n_obs = _mean_squared_error_update(preds, target) | ||
return _mean_squared_error_compute(sum_squared_error, n_obs) |
Oops, something went wrong.