From 371b898d3a8b6f4c6d9203d2a0d58a25cdda7531 Mon Sep 17 00:00:00 2001 From: Paul Louis Date: Sat, 4 Feb 2023 23:08:10 -0500 Subject: [PATCH 01/22] Start out with support for plot() in image metrics --- CHANGELOG.md | 1 + examples/plotting.py | 21 ++++++++- src/torchmetrics/image/d_lambda.py | 64 +++++++++++++++++++++++++- tests/unittests/utilities/test_plot.py | 7 +++ 4 files changed, 91 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aa2202d04c9..667f5907acc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added support for plotting of metrics through `.plot()` method ([#1328](https://github.com/Lightning-AI/metrics/pull/1328)) +- Added support for plotting of image metrics through `.plot()` method ([TODO](TODO)) ### Changed diff --git a/examples/plotting.py b/examples/plotting.py index f877e57c721..3b1070178bf 100644 --- a/examples/plotting.py +++ b/examples/plotting.py @@ -78,12 +78,31 @@ def confusion_matrix_example(): return fig, ax -if __name__ == "__main__": +def spectral_distortion_index_example(): + from torchmetrics.image.d_lambda import SpectralDistortionIndex + + p = lambda: torch.rand([16, 3, 16, 16]) + t = lambda: torch.rand([16, 3, 16, 16]) + # plot single value + metric = SpectralDistortionIndex() + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = SpectralDistortionIndex() + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + + return fig, ax + + +if __name__ == "__main__": metrics_func = { "accuracy": accuracy_example, "mean_squared_error": mean_squared_error_example, "confusion_matrix": confusion_matrix_example, + "spectral_distortion_index": spectral_distortion_index_example, } parser = argparse.ArgumentParser(description="Example script for plotting metrics.") diff --git a/src/torchmetrics/image/d_lambda.py b/src/torchmetrics/image/d_lambda.py index 089f7d5010a..0cb8483d96c 100644 --- a/src/torchmetrics/image/d_lambda.py +++ b/src/torchmetrics/image/d_lambda.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List +from typing import Any, List, Optional, Union, Sequence from torch import Tensor from typing_extensions import Literal @@ -21,6 +21,12 @@ from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["SpectralDistortionIndex.plot"] + class SpectralDistortionIndex(Metric): """Computes Spectral Distortion Index (SpectralDistortionIndex_) also now as D_lambda is used to compare the @@ -60,6 +66,7 @@ class SpectralDistortionIndex(Metric): higher_is_better: bool = True is_differentiable: bool = True full_state_update: bool = False + plot_options = {"lower_bound": 0.0, "upper_bound": 1.0} preds: List[Tensor] target: List[Tensor] @@ -95,3 +102,58 @@ def compute(self) -> Tensor: preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) return _spectral_distortion_index_compute(preds, target, self.p, self.reduction) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + fig: Figure object + ax: Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + Examples: + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics import SpectralDistortionIndex + >>> preds = torch.rand([16, 3, 16, 16]) + >>> target = torch.rand([16, 3, 16, 16]) + >>> metric = SpectralDistortionIndex() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple value + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics import SpectralDistortionIndex + >>> preds = torch.rand([16, 3, 16, 16]) + >>> target = torch.rand([16, 3, 16, 16]) + >>> metric = SpectralDistortionIndex() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + ) + return fig, ax diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 6c09b082fd2..4bfc6ae2da2 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -18,6 +18,7 @@ import numpy as np import pytest import torch +from torchmetrics.functional.image.d_lambda import spectral_distortion_index from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy from torchmetrics.functional.classification.confusion_matrix import ( @@ -49,6 +50,12 @@ lambda: torch.randint(3, (100,)), id="multiclass and average=None", ), + pytest.param( + partial(spectral_distortion_index), + lambda: torch.rand([16, 3, 16, 16]), + lambda: torch.rand([16, 3, 16, 16]), + id="spectral distortion index", + ), ], ) @pytest.mark.parametrize("num_vals", [1, 5, 10]) From dc22016b50bab9ce781cbd4a369993434c49c385 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 5 Feb 2023 04:13:01 +0000 Subject: [PATCH 02/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/image/d_lambda.py | 6 ++---- tests/unittests/utilities/test_plot.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/torchmetrics/image/d_lambda.py b/src/torchmetrics/image/d_lambda.py index 94f619ae204..db737ee5d2c 100644 --- a/src/torchmetrics/image/d_lambda.py +++ b/src/torchmetrics/image/d_lambda.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Union, Sequence +from typing import Any, List, Optional, Sequence, Union from torch import Tensor from typing_extensions import Literal @@ -20,7 +20,6 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat - from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val @@ -104,7 +103,7 @@ def compute(self) -> Tensor: return _spectral_distortion_index_compute(preds, target, self.p, self.reduction) def plot( - self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -150,7 +149,6 @@ def plot( >>> for _ in range(10): ... values.append(metric(preds, target)) >>> fig_, ax_ = metric.plot(values) - """ val = val or self.compute() fig, ax = plot_single_or_multi_val( diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 4bfc6ae2da2..c182d02881e 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -18,7 +18,6 @@ import numpy as np import pytest import torch -from torchmetrics.functional.image.d_lambda import spectral_distortion_index from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy from torchmetrics.functional.classification.confusion_matrix import ( @@ -26,6 +25,7 @@ multiclass_confusion_matrix, multilabel_confusion_matrix, ) +from torchmetrics.functional.image.d_lambda import spectral_distortion_index from torchmetrics.utilities.plot import plot_confusion_matrix, plot_single_or_multi_val From 4159fde59fcb7419dc5036a1326d8dc861aa06d2 Mon Sep 17 00:00:00 2001 From: Paul Louis Date: Sat, 4 Feb 2023 23:14:56 -0500 Subject: [PATCH 03/22] Complete CHANGELOG --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index df017e9b752..52c54d45b6a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for plotting of metrics through `.plot()` method ([#1328](https://github.com/Lightning-AI/metrics/pull/1328)) -- Added support for plotting of image metrics through `.plot()` method ([TODO](TODO)) +- Added support for plotting of image metrics through `.plot()` method ([#1480](https://github.com/Lightning-AI/metrics/pull/1480)) - Added `classes` to output from `MAP` metric ([#1419](https://github.com/Lightning-AI/metrics/pull/1419)) From 644ae1d094e73875611f3a99bab251470a2f9845 Mon Sep 17 00:00:00 2001 From: Paul Louis Date: Sun, 5 Feb 2023 22:26:08 -0500 Subject: [PATCH 04/22] Add plot() support for ergas in image --- examples/plotting.py | 20 +++++++++ src/torchmetrics/image/ergas.py | 60 +++++++++++++++++++++++++- tests/unittests/utilities/test_plot.py | 7 +++ 3 files changed, 86 insertions(+), 1 deletion(-) diff --git a/examples/plotting.py b/examples/plotting.py index 3b1070178bf..dd6aa255eaa 100644 --- a/examples/plotting.py +++ b/examples/plotting.py @@ -97,12 +97,32 @@ def spectral_distortion_index_example(): return fig, ax +def error_relative_global_dimensionless_synthesis(): + from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis + + p = lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)) + t = lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)) + + # plot single value + metric = ErrorRelativeGlobalDimensionlessSynthesis() + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = ErrorRelativeGlobalDimensionlessSynthesis() + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + + return fig, ax + + if __name__ == "__main__": metrics_func = { "accuracy": accuracy_example, "mean_squared_error": mean_squared_error_example, "confusion_matrix": confusion_matrix_example, "spectral_distortion_index": spectral_distortion_index_example, + "error_relative_global_dimensionless_synthesis": error_relative_global_dimensionless_synthesis, } parser = argparse.ArgumentParser(description="Example script for plotting metrics.") diff --git a/src/torchmetrics/image/ergas.py b/src/torchmetrics/image/ergas.py index c7423b866ee..a4a75c64b8d 100644 --- a/src/torchmetrics/image/ergas.py +++ b/src/torchmetrics/image/ergas.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Union +from typing import Any, List, Optional, Sequence, Union from torch import Tensor from typing_extensions import Literal @@ -21,6 +21,11 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["ErrorRelativeGlobalDimensionlessSynthesis.plot"] class ErrorRelativeGlobalDimensionlessSynthesis(Metric): @@ -61,6 +66,7 @@ class ErrorRelativeGlobalDimensionlessSynthesis(Metric): higher_is_better: bool = False is_differentiable: bool = True full_state_update: bool = False + plot_options = {"lower_bound": 0.0, "upper_bound": 1.0} preds: List[Tensor] target: List[Tensor] @@ -94,3 +100,55 @@ def compute(self) -> Tensor: preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) return _ergas_compute(preds, target, self.ratio, self.reduction) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + fig: Figure object + ax: Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + Examples: + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import ErrorRelativeGlobalDimensionlessSynthesis + >>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)) + >>> target = preds * 0.75 + >>> metric = ErrorRelativeGlobalDimensionlessSynthesis() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple value + >>> import torch + >>> from torchmetrics import ErrorRelativeGlobalDimensionlessSynthesis + >>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)) + >>> target = preds * 0.75 + >>> metric = ErrorRelativeGlobalDimensionlessSynthesis() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + ) + return fig, ax diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index c182d02881e..0813a5b4057 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -26,6 +26,7 @@ multilabel_confusion_matrix, ) from torchmetrics.functional.image.d_lambda import spectral_distortion_index +from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis from torchmetrics.utilities.plot import plot_confusion_matrix, plot_single_or_multi_val @@ -56,6 +57,12 @@ lambda: torch.rand([16, 3, 16, 16]), id="spectral distortion index", ), + pytest.param( + partial(error_relative_global_dimensionless_synthesis), + lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)), + lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)), + id="error relative global dimensionless synthesis", + ), ], ) @pytest.mark.parametrize("num_vals", [1, 5, 10]) From 8f4660b2be49b0afc3538bea5a8c16372eb5cce7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Feb 2023 03:26:38 +0000 Subject: [PATCH 05/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/image/ergas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/image/ergas.py b/src/torchmetrics/image/ergas.py index a4a75c64b8d..018f8144f00 100644 --- a/src/torchmetrics/image/ergas.py +++ b/src/torchmetrics/image/ergas.py @@ -102,7 +102,7 @@ def compute(self) -> Tensor: return _ergas_compute(preds, target, self.ratio, self.reduction) def plot( - self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. From 172f6202fbe13d2e555f7fde5b19f5d0f93b0f69 Mon Sep 17 00:00:00 2001 From: Paul Louis Date: Sun, 5 Feb 2023 22:54:15 -0500 Subject: [PATCH 06/22] Add plot() support for psnr in image --- examples/plotting.py | 20 +++++++++ src/torchmetrics/image/psnr.py | 58 ++++++++++++++++++++++++++ tests/unittests/utilities/test_plot.py | 7 ++++ 3 files changed, 85 insertions(+) diff --git a/examples/plotting.py b/examples/plotting.py index dd6aa255eaa..2938dc710d7 100644 --- a/examples/plotting.py +++ b/examples/plotting.py @@ -116,6 +116,25 @@ def error_relative_global_dimensionless_synthesis(): return fig, ax +def peak_signal_noise_ratio(): + from torchmetrics.image.psnr import PeakSignalNoiseRatio + + p = lambda: torch.tensor([[0.0, 1.0], [2.0, 3.0]]) + t = lambda: torch.tensor([[3.0, 2.0], [1.0, 0.0]]) + + # plot single value + metric = PeakSignalNoiseRatio() + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = PeakSignalNoiseRatio() + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + + return fig, ax + + if __name__ == "__main__": metrics_func = { "accuracy": accuracy_example, @@ -123,6 +142,7 @@ def error_relative_global_dimensionless_synthesis(): "confusion_matrix": confusion_matrix_example, "spectral_distortion_index": spectral_distortion_index_example, "error_relative_global_dimensionless_synthesis": error_relative_global_dimensionless_synthesis, + "peak_signal_noise_ratio": peak_signal_noise_ratio, } parser = argparse.ArgumentParser(description="Example script for plotting metrics.") diff --git a/src/torchmetrics/image/psnr.py b/src/torchmetrics/image/psnr.py index ca2e2ad6f5b..66796eb1352 100644 --- a/src/torchmetrics/image/psnr.py +++ b/src/torchmetrics/image/psnr.py @@ -20,6 +20,11 @@ from torchmetrics.functional.image.psnr import _psnr_compute, _psnr_update from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["PeakSignalNoiseRatio.plot"] class PeakSignalNoiseRatio(Metric): @@ -70,6 +75,7 @@ class PeakSignalNoiseRatio(Metric): is_differentiable: bool = True higher_is_better: bool = True full_state_update: bool = False + plot_options = {"lower_bound": 0.0, "upper_bound": 10.0} min_target: Tensor max_target: Tensor @@ -135,3 +141,55 @@ def compute(self) -> Tensor: sum_squared_error = torch.cat([values.flatten() for values in self.sum_squared_error]) total = torch.cat([values.flatten() for values in self.total]) return _psnr_compute(sum_squared_error, total, data_range, base=self.base, reduction=self.reduction) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + fig: Figure object + ax: Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + Examples: + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import PeakSignalNoiseRatio + >>> metric = PeakSignalNoiseRatio() + >>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) + >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple value + >>> import torch + >>> from torchmetrics import PeakSignalNoiseRatio + >>> metric = PeakSignalNoiseRatio() + >>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) + >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + ) + return fig, ax \ No newline at end of file diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 0813a5b4057..7206187d482 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -19,6 +19,7 @@ import pytest import torch +from torchmetrics.functional import peak_signal_noise_ratio from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy from torchmetrics.functional.classification.confusion_matrix import ( binary_confusion_matrix, @@ -63,6 +64,12 @@ lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)), id="error relative global dimensionless synthesis", ), + pytest.param( + partial(peak_signal_noise_ratio), + lambda: torch.tensor([[0.0, 1.0], [2.0, 3.0]]), + lambda: torch.tensor([[3.0, 2.0], [1.0, 0.0]]), + id="peak signal noise ratio", + ), ], ) @pytest.mark.parametrize("num_vals", [1, 5, 10]) From f05c2d65eaa1df0d5bc207290ff949f3f5d18d64 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Feb 2023 03:55:05 +0000 Subject: [PATCH 07/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/image/psnr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/image/psnr.py b/src/torchmetrics/image/psnr.py index 66796eb1352..d3fa46d62f0 100644 --- a/src/torchmetrics/image/psnr.py +++ b/src/torchmetrics/image/psnr.py @@ -143,7 +143,7 @@ def compute(self) -> Tensor: return _psnr_compute(sum_squared_error, total, data_range, base=self.base, reduction=self.reduction) def plot( - self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -192,4 +192,4 @@ def plot( fig, ax = plot_single_or_multi_val( val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ ) - return fig, ax \ No newline at end of file + return fig, ax From 1a879074dfa3ddfefe15d5ede8f0a154bda41089 Mon Sep 17 00:00:00 2001 From: Paul Louis Date: Sun, 5 Feb 2023 23:03:24 -0500 Subject: [PATCH 08/22] Add plot() support for sam in image --- examples/plotting.py | 20 +++++++++ src/torchmetrics/image/d_lambda.py | 2 +- src/torchmetrics/image/ergas.py | 2 +- src/torchmetrics/image/psnr.py | 2 +- src/torchmetrics/image/sam.py | 60 +++++++++++++++++++++++++- tests/unittests/utilities/test_plot.py | 8 +++- 6 files changed, 89 insertions(+), 5 deletions(-) diff --git a/examples/plotting.py b/examples/plotting.py index 2938dc710d7..35eea25c3c1 100644 --- a/examples/plotting.py +++ b/examples/plotting.py @@ -135,6 +135,25 @@ def peak_signal_noise_ratio(): return fig, ax +def spectral_angle_mapper(): + from torchmetrics.image.sam import SpectralAngleMapper + + p = lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42)) + t = lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123)) + + # plot single value + metric = SpectralAngleMapper() + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = SpectralAngleMapper() + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + + return fig, ax + + if __name__ == "__main__": metrics_func = { "accuracy": accuracy_example, @@ -143,6 +162,7 @@ def peak_signal_noise_ratio(): "spectral_distortion_index": spectral_distortion_index_example, "error_relative_global_dimensionless_synthesis": error_relative_global_dimensionless_synthesis, "peak_signal_noise_ratio": peak_signal_noise_ratio, + "spectral_angle_mapper": spectral_angle_mapper } parser = argparse.ArgumentParser(description="Example script for plotting metrics.") diff --git a/src/torchmetrics/image/d_lambda.py b/src/torchmetrics/image/d_lambda.py index db737ee5d2c..f626c175bde 100644 --- a/src/torchmetrics/image/d_lambda.py +++ b/src/torchmetrics/image/d_lambda.py @@ -138,7 +138,7 @@ def plot( .. plot:: :scale: 75 - >>> # Example plotting multiple value + >>> # Example plotting multiple values >>> import torch >>> _ = torch.manual_seed(42) >>> from torchmetrics import SpectralDistortionIndex diff --git a/src/torchmetrics/image/ergas.py b/src/torchmetrics/image/ergas.py index 018f8144f00..75d6192b3ae 100644 --- a/src/torchmetrics/image/ergas.py +++ b/src/torchmetrics/image/ergas.py @@ -136,7 +136,7 @@ def plot( .. plot:: :scale: 75 - >>> # Example plotting multiple value + >>> # Example plotting multiple values >>> import torch >>> from torchmetrics import ErrorRelativeGlobalDimensionlessSynthesis >>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)) diff --git a/src/torchmetrics/image/psnr.py b/src/torchmetrics/image/psnr.py index d3fa46d62f0..63bea5e6ee5 100644 --- a/src/torchmetrics/image/psnr.py +++ b/src/torchmetrics/image/psnr.py @@ -177,7 +177,7 @@ def plot( .. plot:: :scale: 75 - >>> # Example plotting multiple value + >>> # Example plotting multiple values >>> import torch >>> from torchmetrics import PeakSignalNoiseRatio >>> metric = PeakSignalNoiseRatio() diff --git a/src/torchmetrics/image/sam.py b/src/torchmetrics/image/sam.py index be01b73c798..6689626f91d 100644 --- a/src/torchmetrics/image/sam.py +++ b/src/torchmetrics/image/sam.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List +from typing import Any, List, Optional, Union, Sequence from torch import Tensor from typing_extensions import Literal @@ -20,6 +20,11 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["SpectralAngleMapper.plot"] class SpectralAngleMapper(Metric): @@ -62,6 +67,7 @@ class SpectralAngleMapper(Metric): higher_is_better: bool = False is_differentiable: bool = True full_state_update: bool = False + plot_options = {"lower_bound": 0.0, "upper_bound": 1.0} preds: List[Tensor] target: List[Tensor] @@ -92,3 +98,55 @@ def compute(self) -> Tensor: preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) return _sam_compute(preds, target, self.reduction) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + fig: Figure object + ax: Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + Examples: + + .. plot:: + :scale: 75 + + >>> # Example plotting single value + >>> import torch + >>> from torchmetrics import SpectralAngleMapper + >>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42)) + >>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123)) + >>> metric = SpectralAngleMapper() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics import SpectralAngleMapper + >>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42)) + >>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123)) + >>> metric = SpectralAngleMapper() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + ) + return fig, ax \ No newline at end of file diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 7206187d482..a98bfa3bdc8 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -19,7 +19,7 @@ import pytest import torch -from torchmetrics.functional import peak_signal_noise_ratio +from torchmetrics.functional import peak_signal_noise_ratio, spectral_angle_mapper from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy from torchmetrics.functional.classification.confusion_matrix import ( binary_confusion_matrix, @@ -70,6 +70,12 @@ lambda: torch.tensor([[3.0, 2.0], [1.0, 0.0]]), id="peak signal noise ratio", ), + pytest.param( + partial(spectral_angle_mapper), + lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42)), + lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123)), + id="spectral angle mapper", + ), ], ) @pytest.mark.parametrize("num_vals", [1, 5, 10]) From fbcf6d2b3d940401dc1fe0739301bf2382e42baa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Feb 2023 04:03:56 +0000 Subject: [PATCH 09/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/plotting.py | 2 +- src/torchmetrics/image/sam.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/plotting.py b/examples/plotting.py index 35eea25c3c1..c0882eccb0d 100644 --- a/examples/plotting.py +++ b/examples/plotting.py @@ -162,7 +162,7 @@ def spectral_angle_mapper(): "spectral_distortion_index": spectral_distortion_index_example, "error_relative_global_dimensionless_synthesis": error_relative_global_dimensionless_synthesis, "peak_signal_noise_ratio": peak_signal_noise_ratio, - "spectral_angle_mapper": spectral_angle_mapper + "spectral_angle_mapper": spectral_angle_mapper, } parser = argparse.ArgumentParser(description="Example script for plotting metrics.") diff --git a/src/torchmetrics/image/sam.py b/src/torchmetrics/image/sam.py index 6689626f91d..eec08caf201 100644 --- a/src/torchmetrics/image/sam.py +++ b/src/torchmetrics/image/sam.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Union, Sequence +from typing import Any, List, Optional, Sequence, Union from torch import Tensor from typing_extensions import Literal @@ -100,7 +100,7 @@ def compute(self) -> Tensor: return _sam_compute(preds, target, self.reduction) def plot( - self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -149,4 +149,4 @@ def plot( fig, ax = plot_single_or_multi_val( val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ ) - return fig, ax \ No newline at end of file + return fig, ax From 62191f8c723dec31f30de5311b6c95f1d200ce75 Mon Sep 17 00:00:00 2001 From: Paul Louis Date: Sun, 5 Feb 2023 23:18:17 -0500 Subject: [PATCH 10/22] Add plot() support for ssim in image --- examples/plotting.py | 40 +++++++++ src/torchmetrics/image/ssim.py | 111 +++++++++++++++++++++++++ tests/unittests/utilities/test_plot.py | 15 +++- 3 files changed, 165 insertions(+), 1 deletion(-) diff --git a/examples/plotting.py b/examples/plotting.py index c0882eccb0d..0a35ed4caeb 100644 --- a/examples/plotting.py +++ b/examples/plotting.py @@ -154,6 +154,44 @@ def spectral_angle_mapper(): return fig, ax +def structural_similarity_index_measure(): + from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure + + p = lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) + t = lambda: p() * 0.75 + + # plot single value + metric = StructuralSimilarityIndexMeasure() + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = StructuralSimilarityIndexMeasure() + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + + return fig, ax + + +def multiscale_structural_similarity_index_measure(): + from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure + + p = lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) + t = lambda: p() * 0.75 + + # plot single value + metric = MultiScaleStructuralSimilarityIndexMeasure() + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = MultiScaleStructuralSimilarityIndexMeasure() + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + + return fig, ax + + if __name__ == "__main__": metrics_func = { "accuracy": accuracy_example, @@ -163,6 +201,8 @@ def spectral_angle_mapper(): "error_relative_global_dimensionless_synthesis": error_relative_global_dimensionless_synthesis, "peak_signal_noise_ratio": peak_signal_noise_ratio, "spectral_angle_mapper": spectral_angle_mapper, + "structural_similarity_index_measure": structural_similarity_index_measure, + "multiscale_structural_similarity_index_measure": multiscale_structural_similarity_index_measure } parser = argparse.ArgumentParser(description="Example script for plotting metrics.") diff --git a/src/torchmetrics/image/ssim.py b/src/torchmetrics/image/ssim.py index 282fb8c2940..7633ff7ead3 100644 --- a/src/torchmetrics/image/ssim.py +++ b/src/torchmetrics/image/ssim.py @@ -20,6 +20,11 @@ from torchmetrics.functional.image.ssim import _multiscale_ssim_update, _ssim_check_inputs, _ssim_update from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["StructuralSimilarityIndexMeasure.plot", "MultiScaleStructuralSimilarityIndexMeasure.plot"] class StructuralSimilarityIndexMeasure(Metric): @@ -72,6 +77,7 @@ class StructuralSimilarityIndexMeasure(Metric): higher_is_better: bool = True is_differentiable: bool = True full_state_update: bool = False + plot_options = {"lower_bound": 0.0, "upper_bound": 1.0} preds: List[Tensor] target: List[Tensor] @@ -160,6 +166,58 @@ def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: return similarity + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + fig: Figure object + ax: Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + Examples: + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import StructuralSimilarityIndexMeasure + >>> preds = torch.rand([3, 3, 256, 256]) + >>> target = preds * 0.75 + >>> metric = StructuralSimilarityIndexMeasure(data_range=1.0) + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple value + >>> import torch + >>> from torchmetrics import StructuralSimilarityIndexMeasure + >>> preds = torch.rand([3, 3, 256, 256]) + >>> target = preds * 0.75 + >>> metric = StructuralSimilarityIndexMeasure(data_range=1.0) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + ) + return fig, ax + class MultiScaleStructuralSimilarityIndexMeasure(Metric): """Computes `MultiScaleSSIM`_, Multi-scale Structural Similarity Index Measure, which is a generalization of @@ -219,6 +277,7 @@ class MultiScaleStructuralSimilarityIndexMeasure(Metric): higher_is_better: bool = True is_differentiable: bool = True full_state_update: bool = False + plot_options = {"lower_bound": 0.0, "upper_bound": 1.0} preds: List[Tensor] target: List[Tensor] @@ -309,3 +368,55 @@ def compute(self) -> Tensor: return self.similarity else: return self.similarity / self.total + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + fig: Figure object + ax: Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + Examples: + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import MultiScaleStructuralSimilarityIndexMeasure + >>> import torch + >>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) + >>> target = preds * 0.75 + >>> metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0) + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple value + >>> from torchmetrics import MultiScaleStructuralSimilarityIndexMeasure + >>> import torch + >>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) + >>> target = preds * 0.75 + >>> metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + ) + return fig, ax diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index a98bfa3bdc8..0ed7eac26ee 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -19,7 +19,8 @@ import pytest import torch -from torchmetrics.functional import peak_signal_noise_ratio, spectral_angle_mapper +from torchmetrics.functional import peak_signal_noise_ratio, spectral_angle_mapper, structural_similarity_index_measure, \ + multiscale_structural_similarity_index_measure from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy from torchmetrics.functional.classification.confusion_matrix import ( binary_confusion_matrix, @@ -76,6 +77,18 @@ lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123)), id="spectral angle mapper", ), + pytest.param( + partial(structural_similarity_index_measure), + lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)), + lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) * 0.75, + id="structural similarity index_measure", + ), + pytest.param( + partial(multiscale_structural_similarity_index_measure), + lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)), + lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) * 0.75, + id="multiscale structural similarity index measure", + ), ], ) @pytest.mark.parametrize("num_vals", [1, 5, 10]) From 3dc99efb76639fcb3eef217c281da8f0bf93473a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Feb 2023 04:19:25 +0000 Subject: [PATCH 11/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/plotting.py | 2 +- src/torchmetrics/image/ssim.py | 4 ++-- tests/unittests/utilities/test_plot.py | 8 ++++++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/plotting.py b/examples/plotting.py index 0a35ed4caeb..39e5a1f14e7 100644 --- a/examples/plotting.py +++ b/examples/plotting.py @@ -202,7 +202,7 @@ def multiscale_structural_similarity_index_measure(): "peak_signal_noise_ratio": peak_signal_noise_ratio, "spectral_angle_mapper": spectral_angle_mapper, "structural_similarity_index_measure": structural_similarity_index_measure, - "multiscale_structural_similarity_index_measure": multiscale_structural_similarity_index_measure + "multiscale_structural_similarity_index_measure": multiscale_structural_similarity_index_measure, } parser = argparse.ArgumentParser(description="Example script for plotting metrics.") diff --git a/src/torchmetrics/image/ssim.py b/src/torchmetrics/image/ssim.py index 7633ff7ead3..38b13d948cc 100644 --- a/src/torchmetrics/image/ssim.py +++ b/src/torchmetrics/image/ssim.py @@ -167,7 +167,7 @@ def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: return similarity def plot( - self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -370,7 +370,7 @@ def compute(self) -> Tensor: return self.similarity / self.total def plot( - self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 0ed7eac26ee..0fed5546ac8 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -19,8 +19,12 @@ import pytest import torch -from torchmetrics.functional import peak_signal_noise_ratio, spectral_angle_mapper, structural_similarity_index_measure, \ - multiscale_structural_similarity_index_measure +from torchmetrics.functional import ( + multiscale_structural_similarity_index_measure, + peak_signal_noise_ratio, + spectral_angle_mapper, + structural_similarity_index_measure, +) from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy from torchmetrics.functional.classification.confusion_matrix import ( binary_confusion_matrix, From 0e27f89a90eed010dfe6bd9c2fd0a62bb6830a55 Mon Sep 17 00:00:00 2001 From: Paul Louis Date: Sun, 5 Feb 2023 23:29:41 -0500 Subject: [PATCH 12/22] Add plot() support for uqi in image --- examples/plotting.py | 20 +++++++++ src/torchmetrics/image/ssim.py | 4 +- src/torchmetrics/image/uqi.py | 60 +++++++++++++++++++++++++- tests/unittests/utilities/test_plot.py | 14 +++--- 4 files changed, 89 insertions(+), 9 deletions(-) diff --git a/examples/plotting.py b/examples/plotting.py index 39e5a1f14e7..cdb3cbd6971 100644 --- a/examples/plotting.py +++ b/examples/plotting.py @@ -192,6 +192,25 @@ def multiscale_structural_similarity_index_measure(): return fig, ax +def universal_image_quality_index(): + from torchmetrics.image.uqi import UniversalImageQualityIndex + + p = lambda: torch.rand([16, 1, 16, 16]) + t = lambda: p() * 0.75 + + # plot single value + metric = UniversalImageQualityIndex() + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = UniversalImageQualityIndex() + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + + return fig, ax + + if __name__ == "__main__": metrics_func = { "accuracy": accuracy_example, @@ -203,6 +222,7 @@ def multiscale_structural_similarity_index_measure(): "spectral_angle_mapper": spectral_angle_mapper, "structural_similarity_index_measure": structural_similarity_index_measure, "multiscale_structural_similarity_index_measure": multiscale_structural_similarity_index_measure, + "universal_image_quality_index": universal_image_quality_index } parser = argparse.ArgumentParser(description="Example script for plotting metrics.") diff --git a/src/torchmetrics/image/ssim.py b/src/torchmetrics/image/ssim.py index 38b13d948cc..9e6bd8efc73 100644 --- a/src/torchmetrics/image/ssim.py +++ b/src/torchmetrics/image/ssim.py @@ -201,7 +201,7 @@ def plot( .. plot:: :scale: 75 - >>> # Example plotting multiple value + >>> # Example plotting multiple values >>> import torch >>> from torchmetrics import StructuralSimilarityIndexMeasure >>> preds = torch.rand([3, 3, 256, 256]) @@ -404,7 +404,7 @@ def plot( .. plot:: :scale: 75 - >>> # Example plotting multiple value + >>> # Example plotting multiple values >>> from torchmetrics import MultiScaleStructuralSimilarityIndexMeasure >>> import torch >>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) diff --git a/src/torchmetrics/image/uqi.py b/src/torchmetrics/image/uqi.py index a6062777f5b..e08d78660dc 100644 --- a/src/torchmetrics/image/uqi.py +++ b/src/torchmetrics/image/uqi.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence +from typing import Any, List, Optional, Sequence, Union from torch import Tensor from typing_extensions import Literal @@ -20,6 +20,11 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["UniversalImageQualityIndex.plot"] class UniversalImageQualityIndex(Metric): @@ -64,6 +69,7 @@ class UniversalImageQualityIndex(Metric): is_differentiable: bool = True higher_is_better: bool = True full_state_update: bool = False + plot_options = {"lower_bound": 0.0, "upper_bound": 1.0} preds: List[Tensor] target: List[Tensor] @@ -101,3 +107,55 @@ def compute(self) -> Tensor: preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) return _uqi_compute(preds, target, self.kernel_size, self.sigma, self.reduction, self.data_range) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + fig: Figure object + ax: Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + Examples: + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import UniversalImageQualityIndex + >>> preds = torch.rand([16, 1, 16, 16]) + >>> target = preds * 0.75 + >>> metric = UniversalImageQualityIndex() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics import UniversalImageQualityIndex + >>> preds = torch.rand([16, 1, 16, 16]) + >>> target = preds * 0.75 + >>> metric = UniversalImageQualityIndex() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + ) + return fig, ax diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 0fed5546ac8..75748c7ca84 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -19,12 +19,8 @@ import pytest import torch -from torchmetrics.functional import ( - multiscale_structural_similarity_index_measure, - peak_signal_noise_ratio, - spectral_angle_mapper, - structural_similarity_index_measure, -) +from torchmetrics.functional import peak_signal_noise_ratio, spectral_angle_mapper, structural_similarity_index_measure, \ + multiscale_structural_similarity_index_measure, universal_image_quality_index from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy from torchmetrics.functional.classification.confusion_matrix import ( binary_confusion_matrix, @@ -93,6 +89,12 @@ lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) * 0.75, id="multiscale structural similarity index measure", ), + pytest.param( + partial(universal_image_quality_index), + lambda: torch.rand([16, 1, 16, 16]), + lambda: torch.rand([16, 1, 16, 16]) * 0.75, + id="universal image quality index", + ), ], ) @pytest.mark.parametrize("num_vals", [1, 5, 10]) From 07d99f228e2b46c57770c8d4b636e1c6cac278e0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Feb 2023 04:30:28 +0000 Subject: [PATCH 13/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/plotting.py | 2 +- src/torchmetrics/image/uqi.py | 2 +- tests/unittests/utilities/test_plot.py | 9 +++++++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/plotting.py b/examples/plotting.py index cdb3cbd6971..bd75949da86 100644 --- a/examples/plotting.py +++ b/examples/plotting.py @@ -222,7 +222,7 @@ def universal_image_quality_index(): "spectral_angle_mapper": spectral_angle_mapper, "structural_similarity_index_measure": structural_similarity_index_measure, "multiscale_structural_similarity_index_measure": multiscale_structural_similarity_index_measure, - "universal_image_quality_index": universal_image_quality_index + "universal_image_quality_index": universal_image_quality_index, } parser = argparse.ArgumentParser(description="Example script for plotting metrics.") diff --git a/src/torchmetrics/image/uqi.py b/src/torchmetrics/image/uqi.py index e08d78660dc..938a8824310 100644 --- a/src/torchmetrics/image/uqi.py +++ b/src/torchmetrics/image/uqi.py @@ -109,7 +109,7 @@ def compute(self) -> Tensor: return _uqi_compute(preds, target, self.kernel_size, self.sigma, self.reduction, self.data_range) def plot( - self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 75748c7ca84..a852d6c368a 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -19,8 +19,13 @@ import pytest import torch -from torchmetrics.functional import peak_signal_noise_ratio, spectral_angle_mapper, structural_similarity_index_measure, \ - multiscale_structural_similarity_index_measure, universal_image_quality_index +from torchmetrics.functional import ( + multiscale_structural_similarity_index_measure, + peak_signal_noise_ratio, + spectral_angle_mapper, + structural_similarity_index_measure, + universal_image_quality_index, +) from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy from torchmetrics.functional.classification.confusion_matrix import ( binary_confusion_matrix, From 62cdfc3dd3b6d13d0acc66d8d5ad05c24a87a959 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Feb 2023 02:38:29 +0000 Subject: [PATCH 14/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/image/d_lambda.py | 1 - src/torchmetrics/image/ergas.py | 1 - src/torchmetrics/image/psnr.py | 1 - src/torchmetrics/image/sam.py | 1 - src/torchmetrics/image/ssim.py | 2 -- src/torchmetrics/image/uqi.py | 1 - 6 files changed, 7 deletions(-) diff --git a/src/torchmetrics/image/d_lambda.py b/src/torchmetrics/image/d_lambda.py index c7140559294..7ec80e2709e 100644 --- a/src/torchmetrics/image/d_lambda.py +++ b/src/torchmetrics/image/d_lambda.py @@ -121,7 +121,6 @@ def plot( If `matplotlib` is not installed Examples: - .. plot:: :scale: 75 diff --git a/src/torchmetrics/image/ergas.py b/src/torchmetrics/image/ergas.py index a82249af217..4726a71bcd3 100644 --- a/src/torchmetrics/image/ergas.py +++ b/src/torchmetrics/image/ergas.py @@ -120,7 +120,6 @@ def plot( If `matplotlib` is not installed Examples: - .. plot:: :scale: 75 diff --git a/src/torchmetrics/image/psnr.py b/src/torchmetrics/image/psnr.py index be01e249a16..0b8314ee82b 100644 --- a/src/torchmetrics/image/psnr.py +++ b/src/torchmetrics/image/psnr.py @@ -161,7 +161,6 @@ def plot( If `matplotlib` is not installed Examples: - .. plot:: :scale: 75 diff --git a/src/torchmetrics/image/sam.py b/src/torchmetrics/image/sam.py index 4e59da821e3..da11780d0dc 100644 --- a/src/torchmetrics/image/sam.py +++ b/src/torchmetrics/image/sam.py @@ -118,7 +118,6 @@ def plot( If `matplotlib` is not installed Examples: - .. plot:: :scale: 75 diff --git a/src/torchmetrics/image/ssim.py b/src/torchmetrics/image/ssim.py index a6410bdb5fc..77ed8ddcaa4 100644 --- a/src/torchmetrics/image/ssim.py +++ b/src/torchmetrics/image/ssim.py @@ -185,7 +185,6 @@ def plot( If `matplotlib` is not installed Examples: - .. plot:: :scale: 75 @@ -387,7 +386,6 @@ def plot( If `matplotlib` is not installed Examples: - .. plot:: :scale: 75 diff --git a/src/torchmetrics/image/uqi.py b/src/torchmetrics/image/uqi.py index 32f207dcc4e..ebb9c3eb291 100644 --- a/src/torchmetrics/image/uqi.py +++ b/src/torchmetrics/image/uqi.py @@ -127,7 +127,6 @@ def plot( If `matplotlib` is not installed Examples: - .. plot:: :scale: 75 From 77582fbaf69beafe2fb92ab91a60837546bc7936 Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 13 Feb 2023 19:03:13 +0100 Subject: [PATCH 15/22] returns --- src/torchmetrics/classification/accuracy.py | 3 +-- src/torchmetrics/classification/confusion_matrix.py | 3 +-- src/torchmetrics/image/d_lambda.py | 3 +-- src/torchmetrics/image/ergas.py | 3 +-- src/torchmetrics/image/psnr.py | 3 +-- src/torchmetrics/image/sam.py | 3 +-- src/torchmetrics/image/ssim.py | 6 ++---- src/torchmetrics/image/uqi.py | 3 +-- src/torchmetrics/regression/mse.py | 3 +-- 9 files changed, 10 insertions(+), 20 deletions(-) diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index f68a05db6db..45edba4d7ed 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -414,8 +414,7 @@ def plot( ax: An matplotlib axis object. If provided will add plot to that axis Returns: - fig: Figure object - ax: Axes object + Figure and Axes object Raises: ModuleNotFoundError: diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index 2f9277d76a1..896a2f48c12 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -239,8 +239,7 @@ def plot(self, val: Optional[Tensor] = None) -> _PLOT_OUT_TYPE: If no value is provided, will automatically call `metric.compute` and plot that result. Returns: - fig: Figure object - ax: Axes object + Figure and Axes object Raises: ModuleNotFoundError: diff --git a/src/torchmetrics/image/d_lambda.py b/src/torchmetrics/image/d_lambda.py index 7ec80e2709e..5aef03b8680 100644 --- a/src/torchmetrics/image/d_lambda.py +++ b/src/torchmetrics/image/d_lambda.py @@ -113,8 +113,7 @@ def plot( ax: An matplotlib axis object. If provided will add plot to that axis Returns: - fig: Figure object - ax: Axes object + Figure and Axes object Raises: ModuleNotFoundError: diff --git a/src/torchmetrics/image/ergas.py b/src/torchmetrics/image/ergas.py index 4726a71bcd3..02bb18c19e6 100644 --- a/src/torchmetrics/image/ergas.py +++ b/src/torchmetrics/image/ergas.py @@ -112,8 +112,7 @@ def plot( ax: An matplotlib axis object. If provided will add plot to that axis Returns: - fig: Figure object - ax: Axes object + Figure and Axes object Raises: ModuleNotFoundError: diff --git a/src/torchmetrics/image/psnr.py b/src/torchmetrics/image/psnr.py index 0b8314ee82b..809bb6352ca 100644 --- a/src/torchmetrics/image/psnr.py +++ b/src/torchmetrics/image/psnr.py @@ -153,8 +153,7 @@ def plot( ax: An matplotlib axis object. If provided will add plot to that axis Returns: - fig: Figure object - ax: Axes object + Figure and Axes object Raises: ModuleNotFoundError: diff --git a/src/torchmetrics/image/sam.py b/src/torchmetrics/image/sam.py index da11780d0dc..75ad402b6e1 100644 --- a/src/torchmetrics/image/sam.py +++ b/src/torchmetrics/image/sam.py @@ -110,8 +110,7 @@ def plot( ax: An matplotlib axis object. If provided will add plot to that axis Returns: - fig: Figure object - ax: Axes object + Figure and Axes object Raises: ModuleNotFoundError: diff --git a/src/torchmetrics/image/ssim.py b/src/torchmetrics/image/ssim.py index 77ed8ddcaa4..3be432cae90 100644 --- a/src/torchmetrics/image/ssim.py +++ b/src/torchmetrics/image/ssim.py @@ -177,8 +177,7 @@ def plot( ax: An matplotlib axis object. If provided will add plot to that axis Returns: - fig: Figure object - ax: Axes object + Figure and Axes object Raises: ModuleNotFoundError: @@ -378,8 +377,7 @@ def plot( ax: An matplotlib axis object. If provided will add plot to that axis Returns: - fig: Figure object - ax: Axes object + Figure and Axes object Raises: ModuleNotFoundError: diff --git a/src/torchmetrics/image/uqi.py b/src/torchmetrics/image/uqi.py index ebb9c3eb291..c773a51d846 100644 --- a/src/torchmetrics/image/uqi.py +++ b/src/torchmetrics/image/uqi.py @@ -119,8 +119,7 @@ def plot( ax: An matplotlib axis object. If provided will add plot to that axis Returns: - fig: Figure object - ax: Axes object + Figure and Axes object Raises: ModuleNotFoundError: diff --git a/src/torchmetrics/regression/mse.py b/src/torchmetrics/regression/mse.py index 1ed2ecf64f9..1b1e32986c5 100644 --- a/src/torchmetrics/regression/mse.py +++ b/src/torchmetrics/regression/mse.py @@ -93,8 +93,7 @@ def plot( ax: An matplotlib axis object. If provided will add plot to that axis Returns: - fig: Figure object - ax: Axes object + Figure and Axes object Raises: ModuleNotFoundError: From 3d20634473e846540ef1b65ee91e11a105f52a04 Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 13 Feb 2023 19:03:59 +0100 Subject: [PATCH 16/22] examples --- src/torchmetrics/image/d_lambda.py | 1 - src/torchmetrics/image/ergas.py | 1 - src/torchmetrics/image/psnr.py | 1 - src/torchmetrics/image/sam.py | 1 - src/torchmetrics/image/ssim.py | 2 -- src/torchmetrics/image/uqi.py | 1 - 6 files changed, 7 deletions(-) diff --git a/src/torchmetrics/image/d_lambda.py b/src/torchmetrics/image/d_lambda.py index 5aef03b8680..2c5d0850b20 100644 --- a/src/torchmetrics/image/d_lambda.py +++ b/src/torchmetrics/image/d_lambda.py @@ -119,7 +119,6 @@ def plot( ModuleNotFoundError: If `matplotlib` is not installed - Examples: .. plot:: :scale: 75 diff --git a/src/torchmetrics/image/ergas.py b/src/torchmetrics/image/ergas.py index 02bb18c19e6..7b0260bdf18 100644 --- a/src/torchmetrics/image/ergas.py +++ b/src/torchmetrics/image/ergas.py @@ -118,7 +118,6 @@ def plot( ModuleNotFoundError: If `matplotlib` is not installed - Examples: .. plot:: :scale: 75 diff --git a/src/torchmetrics/image/psnr.py b/src/torchmetrics/image/psnr.py index 809bb6352ca..2d83b2b98d3 100644 --- a/src/torchmetrics/image/psnr.py +++ b/src/torchmetrics/image/psnr.py @@ -159,7 +159,6 @@ def plot( ModuleNotFoundError: If `matplotlib` is not installed - Examples: .. plot:: :scale: 75 diff --git a/src/torchmetrics/image/sam.py b/src/torchmetrics/image/sam.py index 75ad402b6e1..a1ba80428dc 100644 --- a/src/torchmetrics/image/sam.py +++ b/src/torchmetrics/image/sam.py @@ -116,7 +116,6 @@ def plot( ModuleNotFoundError: If `matplotlib` is not installed - Examples: .. plot:: :scale: 75 diff --git a/src/torchmetrics/image/ssim.py b/src/torchmetrics/image/ssim.py index 3be432cae90..656483997ad 100644 --- a/src/torchmetrics/image/ssim.py +++ b/src/torchmetrics/image/ssim.py @@ -183,7 +183,6 @@ def plot( ModuleNotFoundError: If `matplotlib` is not installed - Examples: .. plot:: :scale: 75 @@ -383,7 +382,6 @@ def plot( ModuleNotFoundError: If `matplotlib` is not installed - Examples: .. plot:: :scale: 75 diff --git a/src/torchmetrics/image/uqi.py b/src/torchmetrics/image/uqi.py index c773a51d846..7f57f089f23 100644 --- a/src/torchmetrics/image/uqi.py +++ b/src/torchmetrics/image/uqi.py @@ -125,7 +125,6 @@ def plot( ModuleNotFoundError: If `matplotlib` is not installed - Examples: .. plot:: :scale: 75 From 682992f0b7486b7afeda2bef3e8df468cf2dcca7 Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 13 Feb 2023 19:08:08 +0100 Subject: [PATCH 17/22] docs --- examples/plotting.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/plotting.py b/examples/plotting.py index 6a66a20589c..a50dd1dd8d2 100644 --- a/examples/plotting.py +++ b/examples/plotting.py @@ -82,6 +82,7 @@ def confusion_matrix_example(): def spectral_distortion_index_example(): + """Plot spectral distortion index example example.""" from torchmetrics.image.d_lambda import SpectralDistortionIndex p = lambda: torch.rand([16, 3, 16, 16]) @@ -101,6 +102,7 @@ def spectral_distortion_index_example(): def error_relative_global_dimensionless_synthesis(): + """Plot error relative global dimensionless synthesis example.""" from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis p = lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)) @@ -120,6 +122,7 @@ def error_relative_global_dimensionless_synthesis(): def peak_signal_noise_ratio(): + """Plot peak signal noise ratio example.""" from torchmetrics.image.psnr import PeakSignalNoiseRatio p = lambda: torch.tensor([[0.0, 1.0], [2.0, 3.0]]) @@ -139,6 +142,7 @@ def peak_signal_noise_ratio(): def spectral_angle_mapper(): + """Plot spectral angle mapper example.""" from torchmetrics.image.sam import SpectralAngleMapper p = lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42)) @@ -158,6 +162,7 @@ def spectral_angle_mapper(): def structural_similarity_index_measure(): + """Plot structural similarity index measure example.""" from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure p = lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) @@ -177,6 +182,7 @@ def structural_similarity_index_measure(): def multiscale_structural_similarity_index_measure(): + """Plot multiscale structural similarity index measure example.""" from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure p = lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) @@ -196,6 +202,7 @@ def multiscale_structural_similarity_index_measure(): def universal_image_quality_index(): + """Plot universal image quality index example.""" from torchmetrics.image.uqi import UniversalImageQualityIndex p = lambda: torch.rand([16, 1, 16, 16]) From c0a7b30c565645c7d8b78e6e5c0be6dad3fc39ec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 18 Feb 2023 06:34:17 +0000 Subject: [PATCH 18/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/utilities/test_plot.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 6c4c827776a..1b611aa614c 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -22,13 +22,13 @@ from torchmetrics.functional import ( multiscale_structural_similarity_index_measure, peak_signal_noise_ratio, - spectral_angle_mapper, - structural_similarity_index_measure, - universal_image_quality_index, scale_invariant_signal_distortion_ratio, scale_invariant_signal_noise_ratio, signal_distortion_ratio, signal_noise_ratio, + spectral_angle_mapper, + structural_similarity_index_measure, + universal_image_quality_index, ) from torchmetrics.functional.audio import short_time_objective_intelligibility from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality From 116c1a5f91dcdf516b0ae3dcbed99bb84a78b4b2 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 21 Feb 2023 00:56:29 +0100 Subject: [PATCH 19/22] fix --- tests/unittests/utilities/test_plot.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 1b611aa614c..d55e5ab4f40 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -51,71 +51,72 @@ binary_accuracy, lambda: torch.rand(100), lambda: torch.randint(2, (100,)), - id="binary", + id="binary" ), pytest.param( partial(multiclass_accuracy, num_classes=3), lambda: torch.randint(3, (100,)), lambda: torch.randint(3, (100,)), - id="multiclass", + id="multiclass" ), pytest.param( partial(multiclass_accuracy, num_classes=3, average=None), lambda: torch.randint(3, (100,)), lambda: torch.randint(3, (100,)), - id="multiclass and average=None", + id="multiclass and average=None" ), pytest.param( partial(spectral_distortion_index), lambda: torch.rand([16, 3, 16, 16]), lambda: torch.rand([16, 3, 16, 16]), - id="spectral distortion index", + id="spectral distortion index" ), pytest.param( partial(error_relative_global_dimensionless_synthesis), lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)), lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)), - id="error relative global dimensionless synthesis", + id="error relative global dimensionless synthesis" ), pytest.param( partial(peak_signal_noise_ratio), lambda: torch.tensor([[0.0, 1.0], [2.0, 3.0]]), lambda: torch.tensor([[3.0, 2.0], [1.0, 0.0]]), - id="peak signal noise ratio", + id="peak signal noise ratio" ), pytest.param( partial(spectral_angle_mapper), lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42)), lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123)), - id="spectral angle mapper", + id="spectral angle mapper" ), pytest.param( partial(structural_similarity_index_measure), lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)), lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) * 0.75, - id="structural similarity index_measure", + id="structural similarity index_measure" ), pytest.param( partial(multiscale_structural_similarity_index_measure), lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)), lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) * 0.75, - id="multiscale structural similarity index measure", + id="multiscale structural similarity index measure" ), pytest.param( partial(universal_image_quality_index), lambda: torch.rand([16, 1, 16, 16]), lambda: torch.rand([16, 1, 16, 16]) * 0.75, - id="universal image quality index", + id="universal image quality index"), + pytest.param( partial(perceptual_evaluation_speech_quality, fs=8000, mode="nb"), lambda: torch.randn(8000), lambda: torch.randn(8000), - id="perceptual_evaluation_speech_quality", + id="perceptual_evaluation_speech_quality" ), pytest.param( partial(signal_distortion_ratio), lambda: torch.randn(8000), lambda: torch.randn(8000), - id="signal_distortion_ratio", + id="signal_distortion_ratio" ), pytest.param( partial(scale_invariant_signal_distortion_ratio), From 8d78bb3efab41f8dccccf520e16f8773ca3e7555 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Feb 2023 23:56:58 +0000 Subject: [PATCH 20/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/utilities/test_plot.py | 30 +++++++++++--------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index d55e5ab4f40..f9d251c49ea 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -47,76 +47,72 @@ @pytest.mark.parametrize( ("metric", "preds", "target"), [ - pytest.param( - binary_accuracy, - lambda: torch.rand(100), - lambda: torch.randint(2, (100,)), - id="binary" - ), + pytest.param(binary_accuracy, lambda: torch.rand(100), lambda: torch.randint(2, (100,)), id="binary"), pytest.param( partial(multiclass_accuracy, num_classes=3), lambda: torch.randint(3, (100,)), lambda: torch.randint(3, (100,)), - id="multiclass" + id="multiclass", ), pytest.param( partial(multiclass_accuracy, num_classes=3, average=None), lambda: torch.randint(3, (100,)), lambda: torch.randint(3, (100,)), - id="multiclass and average=None" + id="multiclass and average=None", ), pytest.param( partial(spectral_distortion_index), lambda: torch.rand([16, 3, 16, 16]), lambda: torch.rand([16, 3, 16, 16]), - id="spectral distortion index" + id="spectral distortion index", ), pytest.param( partial(error_relative_global_dimensionless_synthesis), lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)), lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)), - id="error relative global dimensionless synthesis" + id="error relative global dimensionless synthesis", ), pytest.param( partial(peak_signal_noise_ratio), lambda: torch.tensor([[0.0, 1.0], [2.0, 3.0]]), lambda: torch.tensor([[3.0, 2.0], [1.0, 0.0]]), - id="peak signal noise ratio" + id="peak signal noise ratio", ), pytest.param( partial(spectral_angle_mapper), lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42)), lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123)), - id="spectral angle mapper" + id="spectral angle mapper", ), pytest.param( partial(structural_similarity_index_measure), lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)), lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) * 0.75, - id="structural similarity index_measure" + id="structural similarity index_measure", ), pytest.param( partial(multiscale_structural_similarity_index_measure), lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)), lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) * 0.75, - id="multiscale structural similarity index measure" + id="multiscale structural similarity index measure", ), pytest.param( partial(universal_image_quality_index), lambda: torch.rand([16, 1, 16, 16]), lambda: torch.rand([16, 1, 16, 16]) * 0.75, - id="universal image quality index"), + id="universal image quality index", + ), pytest.param( partial(perceptual_evaluation_speech_quality, fs=8000, mode="nb"), lambda: torch.randn(8000), lambda: torch.randn(8000), - id="perceptual_evaluation_speech_quality" + id="perceptual_evaluation_speech_quality", ), pytest.param( partial(signal_distortion_ratio), lambda: torch.randn(8000), lambda: torch.randn(8000), - id="signal_distortion_ratio" + id="signal_distortion_ratio", ), pytest.param( partial(scale_invariant_signal_distortion_ratio), From 4f952dd7edd24e15dc569a1b64ae83a780be055a Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 21 Feb 2023 12:51:42 +0100 Subject: [PATCH 21/22] kwargs --- src/torchmetrics/image/d_lambda.py | 2 +- src/torchmetrics/image/ergas.py | 2 +- src/torchmetrics/image/sam.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/image/d_lambda.py b/src/torchmetrics/image/d_lambda.py index 2c5d0850b20..5ed1685fbe0 100644 --- a/src/torchmetrics/image/d_lambda.py +++ b/src/torchmetrics/image/d_lambda.py @@ -149,6 +149,6 @@ def plot( """ val = val or self.compute() fig, ax = plot_single_or_multi_val( - val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + val, ax=ax, higher_is_better=self.higher_is_better, name=self.__class__.__name__, **self.plot_options ) return fig, ax diff --git a/src/torchmetrics/image/ergas.py b/src/torchmetrics/image/ergas.py index 7b0260bdf18..3e563d23dbe 100644 --- a/src/torchmetrics/image/ergas.py +++ b/src/torchmetrics/image/ergas.py @@ -146,6 +146,6 @@ def plot( """ val = val or self.compute() fig, ax = plot_single_or_multi_val( - val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + val, ax=ax, higher_is_better=self.higher_is_better, name=self.__class__.__name__, **self.plot_options ) return fig, ax diff --git a/src/torchmetrics/image/sam.py b/src/torchmetrics/image/sam.py index a1ba80428dc..566467e7505 100644 --- a/src/torchmetrics/image/sam.py +++ b/src/torchmetrics/image/sam.py @@ -144,6 +144,6 @@ def plot( """ val = val or self.compute() fig, ax = plot_single_or_multi_val( - val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + val, ax=ax, higher_is_better=self.higher_is_better, name=self.__class__.__name__, **self.plot_options ) return fig, ax From 487de55f0cf8b4d4cee021425bf272ccd269dde1 Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 22 Feb 2023 18:43:55 +0100 Subject: [PATCH 22/22] typing --- src/torchmetrics/image/d_lambda.py | 10 +++++++--- src/torchmetrics/image/ergas.py | 10 +++++++--- src/torchmetrics/image/sam.py | 10 +++++++--- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/torchmetrics/image/d_lambda.py b/src/torchmetrics/image/d_lambda.py index 5ed1685fbe0..4386e23f849 100644 --- a/src/torchmetrics/image/d_lambda.py +++ b/src/torchmetrics/image/d_lambda.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Union from torch import Tensor from typing_extensions import Literal @@ -65,7 +65,6 @@ class SpectralDistortionIndex(Metric): higher_is_better: bool = True is_differentiable: bool = True full_state_update: bool = False - plot_options = {"lower_bound": 0.0, "upper_bound": 1.0} preds: List[Tensor] target: List[Tensor] @@ -149,6 +148,11 @@ def plot( """ val = val or self.compute() fig, ax = plot_single_or_multi_val( - val, ax=ax, higher_is_better=self.higher_is_better, name=self.__class__.__name__, **self.plot_options + val, + ax=ax, + higher_is_better=self.higher_is_better, + name=self.__class__.__name__, + lower_bound=0.0, + upper_bound=1.0, ) return fig, ax diff --git a/src/torchmetrics/image/ergas.py b/src/torchmetrics/image/ergas.py index 3e563d23dbe..1c71a42540e 100644 --- a/src/torchmetrics/image/ergas.py +++ b/src/torchmetrics/image/ergas.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Union from torch import Tensor from typing_extensions import Literal @@ -66,7 +66,6 @@ class ErrorRelativeGlobalDimensionlessSynthesis(Metric): higher_is_better: bool = False is_differentiable: bool = True full_state_update: bool = False - plot_options = {"lower_bound": 0.0, "upper_bound": 1.0} preds: List[Tensor] target: List[Tensor] @@ -146,6 +145,11 @@ def plot( """ val = val or self.compute() fig, ax = plot_single_or_multi_val( - val, ax=ax, higher_is_better=self.higher_is_better, name=self.__class__.__name__, **self.plot_options + val, + ax=ax, + higher_is_better=self.higher_is_better, + name=self.__class__.__name__, + lower_bound=0.0, + upper_bound=1.0, ) return fig, ax diff --git a/src/torchmetrics/image/sam.py b/src/torchmetrics/image/sam.py index 566467e7505..1d1611a5afd 100644 --- a/src/torchmetrics/image/sam.py +++ b/src/torchmetrics/image/sam.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Union from torch import Tensor from typing_extensions import Literal @@ -67,7 +67,6 @@ class SpectralAngleMapper(Metric): higher_is_better: bool = False is_differentiable: bool = True full_state_update: bool = False - plot_options = {"lower_bound": 0.0, "upper_bound": 1.0} preds: List[Tensor] target: List[Tensor] @@ -144,6 +143,11 @@ def plot( """ val = val or self.compute() fig, ax = plot_single_or_multi_val( - val, ax=ax, higher_is_better=self.higher_is_better, name=self.__class__.__name__, **self.plot_options + val, + ax=ax, + higher_is_better=self.higher_is_better, + name=self.__class__.__name__, + lower_bound=0.0, + upper_bound=1.0, ) return fig, ax