Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for plot() in image metrics #1480

Merged
merged 42 commits into from
Feb 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
371b898
Start out with support for plot() in image metrics
venomouscyanide Feb 5, 2023
3d54317
Merge branch 'master' into add_plots_for_image_metrics
venomouscyanide Feb 5, 2023
dc22016
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2023
4159fde
Complete CHANGELOG
venomouscyanide Feb 5, 2023
644ae1d
Add plot() support for ergas in image
venomouscyanide Feb 6, 2023
8f4660b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2023
172f620
Add plot() support for psnr in image
venomouscyanide Feb 6, 2023
f05c2d6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2023
1a87907
Add plot() support for sam in image
venomouscyanide Feb 6, 2023
fbcf6d2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2023
62191f8
Add plot() support for ssim in image
venomouscyanide Feb 6, 2023
3dc99ef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2023
0e27f89
Add plot() support for uqi in image
venomouscyanide Feb 6, 2023
07d99f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2023
961e47b
Merge branch 'master' into add_plots_for_image_metrics
Borda Feb 7, 2023
62cdfc3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 7, 2023
613ee58
Merge branch 'master' into add_plots_for_image_metrics
venomouscyanide Feb 7, 2023
77582fb
returns
Borda Feb 13, 2023
3d20634
examples
Borda Feb 13, 2023
682992f
docs
Borda Feb 13, 2023
3cfc83c
Merge branch 'master' into add_plots_for_image_metrics
Borda Feb 14, 2023
92e6080
Merge branch 'master' into add_plots_for_image_metrics
Borda Feb 18, 2023
c0a7b30
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 18, 2023
7b4b41f
Merge branch 'master' into add_plots_for_image_metrics
mergify[bot] Feb 20, 2023
436324b
Merge branch 'master' into add_plots_for_image_metrics
Borda Feb 20, 2023
9757dec
Merge branch 'master' into add_plots_for_image_metrics
mergify[bot] Feb 20, 2023
b1eb793
Merge branch 'master' into add_plots_for_image_metrics
mergify[bot] Feb 20, 2023
7fc55e3
Merge branch 'master' into add_plots_for_image_metrics
mergify[bot] Feb 20, 2023
3d6ec13
Merge branch 'master' into add_plots_for_image_metrics
Borda Feb 20, 2023
9d93ac5
Merge branch 'master' into add_plots_for_image_metrics
mergify[bot] Feb 20, 2023
116c1a5
fix
Borda Feb 20, 2023
8d78bb3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2023
2aafbe1
Merge branch 'master' into add_plots_for_image_metrics
mergify[bot] Feb 21, 2023
4f952dd
kwargs
Borda Feb 21, 2023
8c2d2aa
Merge branch 'master' into add_plots_for_image_metrics
Borda Feb 22, 2023
ddcc004
Merge branch 'master' into add_plots_for_image_metrics
Borda Feb 22, 2023
0b30455
Merge branch 'master' into add_plots_for_image_metrics
mergify[bot] Feb 22, 2023
d67ed5d
Merge branch 'master' into add_plots_for_image_metrics
mergify[bot] Feb 22, 2023
9b86681
Merge branch 'master' into add_plots_for_image_metrics
mergify[bot] Feb 22, 2023
14ab6ec
Merge branch 'master' into add_plots_for_image_metrics
Borda Feb 22, 2023
487de55
typing
Borda Feb 22, 2023
c99063b
Merge branch 'master' into add_plots_for_image_metrics
mergify[bot] Feb 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added support for plotting of metrics through `.plot()` method (
[#1328](https://github.com/Lightning-AI/metrics/pull/1328),
[#1481](https://github.com/Lightning-AI/metrics/pull/1481)
[#1481](https://github.com/Lightning-AI/metrics/pull/1481),
[#1480](https://github.com/Lightning-AI/metrics/pull/1480)
)


Expand Down
148 changes: 147 additions & 1 deletion examples/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,147 @@ def confusion_matrix_example():
return fig, ax


if __name__ == "__main__":
def spectral_distortion_index_example():
"""Plot spectral distortion index example example."""
from torchmetrics.image.d_lambda import SpectralDistortionIndex

p = lambda: torch.rand([16, 3, 16, 16])
t = lambda: torch.rand([16, 3, 16, 16])

# plot single value
metric = SpectralDistortionIndex()
metric.update(p(), t())
fig, ax = metric.plot()

# plot multiple values
metric = SpectralDistortionIndex()
vals = [metric(p(), t()) for _ in range(10)]
fig, ax = metric.plot(vals)

return fig, ax


def error_relative_global_dimensionless_synthesis():
"""Plot error relative global dimensionless synthesis example."""
from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis

p = lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
t = lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))

# plot single value
metric = ErrorRelativeGlobalDimensionlessSynthesis()
metric.update(p(), t())
fig, ax = metric.plot()

# plot multiple values
metric = ErrorRelativeGlobalDimensionlessSynthesis()
vals = [metric(p(), t()) for _ in range(10)]
fig, ax = metric.plot(vals)

return fig, ax


def peak_signal_noise_ratio():
"""Plot peak signal noise ratio example."""
from torchmetrics.image.psnr import PeakSignalNoiseRatio

p = lambda: torch.tensor([[0.0, 1.0], [2.0, 3.0]])
t = lambda: torch.tensor([[3.0, 2.0], [1.0, 0.0]])

# plot single value
metric = PeakSignalNoiseRatio()
metric.update(p(), t())
fig, ax = metric.plot()

# plot multiple values
metric = PeakSignalNoiseRatio()
vals = [metric(p(), t()) for _ in range(10)]
fig, ax = metric.plot(vals)

return fig, ax


def spectral_angle_mapper():
"""Plot spectral angle mapper example."""
from torchmetrics.image.sam import SpectralAngleMapper

p = lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42))
t = lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123))

# plot single value
metric = SpectralAngleMapper()
metric.update(p(), t())
fig, ax = metric.plot()

# plot multiple values
metric = SpectralAngleMapper()
vals = [metric(p(), t()) for _ in range(10)]
fig, ax = metric.plot(vals)

return fig, ax


def structural_similarity_index_measure():
"""Plot structural similarity index measure example."""
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure

p = lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
t = lambda: p() * 0.75

# plot single value
metric = StructuralSimilarityIndexMeasure()
metric.update(p(), t())
fig, ax = metric.plot()

# plot multiple values
metric = StructuralSimilarityIndexMeasure()
vals = [metric(p(), t()) for _ in range(10)]
fig, ax = metric.plot(vals)

return fig, ax


def multiscale_structural_similarity_index_measure():
"""Plot multiscale structural similarity index measure example."""
from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure

p = lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
t = lambda: p() * 0.75

# plot single value
metric = MultiScaleStructuralSimilarityIndexMeasure()
metric.update(p(), t())
fig, ax = metric.plot()

# plot multiple values
metric = MultiScaleStructuralSimilarityIndexMeasure()
vals = [metric(p(), t()) for _ in range(10)]
fig, ax = metric.plot(vals)

return fig, ax


def universal_image_quality_index():
"""Plot universal image quality index example."""
from torchmetrics.image.uqi import UniversalImageQualityIndex

p = lambda: torch.rand([16, 1, 16, 16])
t = lambda: p() * 0.75

# plot single value
metric = UniversalImageQualityIndex()
metric.update(p(), t())
fig, ax = metric.plot()

# plot multiple values
metric = UniversalImageQualityIndex()
vals = [metric(p(), t()) for _ in range(10)]
fig, ax = metric.plot(vals)

return fig, ax


if __name__ == "__main__":
metrics_func = {
"accuracy": accuracy_example,
"pesq": pesq_example,
Expand All @@ -235,6 +374,13 @@ def confusion_matrix_example():
"stoi": stoi_example,
"mean_squared_error": mean_squared_error_example,
"confusion_matrix": confusion_matrix_example,
"spectral_distortion_index": spectral_distortion_index_example,
"error_relative_global_dimensionless_synthesis": error_relative_global_dimensionless_synthesis,
"peak_signal_noise_ratio": peak_signal_noise_ratio,
"spectral_angle_mapper": spectral_angle_mapper,
"structural_similarity_index_measure": structural_similarity_index_measure,
"multiscale_structural_similarity_index_measure": multiscale_structural_similarity_index_measure,
"universal_image_quality_index": universal_image_quality_index,
}

parser = argparse.ArgumentParser(description="Example script for plotting metrics.")
Expand Down
3 changes: 1 addition & 2 deletions src/torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,7 @@ def plot(
ax: An matplotlib axis object. If provided will add plot to that axis

Returns:
fig: Figure object
ax: Axes object
Figure and Axes object

Raises:
ModuleNotFoundError:
Expand Down
3 changes: 1 addition & 2 deletions src/torchmetrics/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,7 @@ def plot(self, val: Optional[Tensor] = None) -> _PLOT_OUT_TYPE:
If no value is provided, will automatically call `metric.compute` and plot that result.

Returns:
fig: Figure object
ax: Axes object
Figure and Axes object

Raises:
ModuleNotFoundError:
Expand Down
63 changes: 62 additions & 1 deletion src/torchmetrics/image/d_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List
from typing import Any, Dict, List, Optional, Sequence, Union

from torch import Tensor
from typing_extensions import Literal
Expand All @@ -20,6 +20,11 @@
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["SpectralDistortionIndex.plot"]


class SpectralDistortionIndex(Metric):
Expand Down Expand Up @@ -95,3 +100,59 @@ def compute(self) -> Tensor:
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
return _spectral_distortion_index_compute(preds, target, self.p, self.reduction)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> metric = SpectralDistortionIndex()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> metric = SpectralDistortionIndex()
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
"""
val = val or self.compute()
fig, ax = plot_single_or_multi_val(
val,
ax=ax,
higher_is_better=self.higher_is_better,
name=self.__class__.__name__,
lower_bound=0.0,
upper_bound=1.0,
)
return fig, ax
61 changes: 60 additions & 1 deletion src/torchmetrics/image/ergas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Union
from typing import Any, Dict, List, Optional, Sequence, Union

from torch import Tensor
from typing_extensions import Literal
Expand All @@ -21,6 +21,11 @@
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["ErrorRelativeGlobalDimensionlessSynthesis.plot"]


class ErrorRelativeGlobalDimensionlessSynthesis(Metric):
Expand Down Expand Up @@ -94,3 +99,57 @@ def compute(self) -> Tensor:
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
return _ergas_compute(preds, target, self.ratio, self.reduction)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics import ErrorRelativeGlobalDimensionlessSynthesis
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> metric = ErrorRelativeGlobalDimensionlessSynthesis()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics import ErrorRelativeGlobalDimensionlessSynthesis
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> metric = ErrorRelativeGlobalDimensionlessSynthesis()
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
"""
val = val or self.compute()
fig, ax = plot_single_or_multi_val(
val,
ax=ax,
higher_is_better=self.higher_is_better,
name=self.__class__.__name__,
lower_bound=0.0,
upper_bound=1.0,
)
return fig, ax
Loading