diff --git a/CHANGELOG.md b/CHANGELOG.md index 12f7b95418e..648cde46870 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `MutualInformationScore` metric to cluster package ([#2008](https://github.com/Lightning-AI/torchmetrics/pull/2008) +- Added `RandScore` metric to cluster package ([#2025](https://github.com/Lightning-AI/torchmetrics/pull/2025)) + + ### Changed - diff --git a/docs/source/clustering/rand_score.rst b/docs/source/clustering/rand_score.rst new file mode 100644 index 00000000000..62650c2d454 --- /dev/null +++ b/docs/source/clustering/rand_score.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: Rand Score + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/clustering.svg + :tags: Clustering + +.. include:: ../links.rst + +########## +Rand Score +########## + +Module Interface +________________ + +.. autoclass:: torchmetrics.clustering.RandScore + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.clustering.rand_score diff --git a/docs/source/links.rst b/docs/source/links.rst index 7627490c661..7e875191a1f 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -152,3 +152,4 @@ .. _GIOU: https://arxiv.org/abs/1902.09630 .. _Mutual Information Score: https://en.wikipedia.org/wiki/Mutual_information .. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools +.. _Rand Score: https://link.springer.com/article/10.1007/BF01908075 diff --git a/src/torchmetrics/clustering/__init__.py b/src/torchmetrics/clustering/__init__.py index baeb8c88d31..118e59cc451 100644 --- a/src/torchmetrics/clustering/__init__.py +++ b/src/torchmetrics/clustering/__init__.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from torchmetrics.clustering.mutual_info_score import MutualInfoScore +from torchmetrics.clustering.rand_score import RandScore __all__ = [ "MutualInfoScore", + "RandScore", ] diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index 504d12c2718..b85d60a6f71 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -41,8 +41,8 @@ class MutualInfoScore(Metric): As input to ``forward`` and ``update`` the metric accepts the following input: - - ``preds`` (:class:`~torch.Tensor`): either single output float tensor with shape ``(N,)`` - - ``target`` (:class:`~torch.Tensor`): either single output tensor with shape ``(N,)`` + - ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels + - ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels As output of ``forward`` and ``compute`` the metric returns the following output: diff --git a/src/torchmetrics/clustering/rand_score.py b/src/torchmetrics/clustering/rand_score.py new file mode 100644 index 00000000000..a7fa5bb83f8 --- /dev/null +++ b/src/torchmetrics/clustering/rand_score.py @@ -0,0 +1,122 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class RandScore(Metric):
    r"""Compute `Rand Score`_ (alternatively known as Rand Index).

    .. math::
        RS(U, V) = \text{number of agreeing pairs} / \text{number of pairs}

    A pair of samples :math:`(i, j)` counts as agreeing when the two clusterings :math:`U` and :math:`V` (the
    predicted and true clusterings, respectively) treat it the same way: either both put :math:`i` and :math:`j` in
    the same cluster, or both put them in different clusters.

    The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields the same rand score.

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels
    - ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels

    As output of ``forward`` and ``compute`` the metric returns the following output:

    - ``rand_score`` (:class:`~torch.Tensor`): A tensor with the Rand Score

    Args:
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Example:
        >>> import torch
        >>> from torchmetrics.clustering import RandScore
        >>> preds = torch.tensor([2, 1, 0, 1, 0])
        >>> target = torch.tensor([0, 2, 1, 1, 0])
        >>> metric = RandScore()
        >>> metric(preds, target)
        tensor(0.6000)

    """

    is_differentiable = True
    higher_is_better = None
    full_state_update: bool = True
    plot_lower_bound: float = 0.0
    # Label tensors are stored batch by batch and only concatenated in ``compute``.
    preds: List[Tensor]
    target: List[Tensor]

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        # Both states are lists of per-batch tensors that are concatenated when synced across processes.
        self.add_state("preds", default=[], dist_reduce_fx="cat")
        self.add_state("target", default=[], dist_reduce_fx="cat")

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update state with predictions and targets.

        Args:
            preds: predicted cluster labels for the current batch
            target: ground truth cluster labels for the current batch

        """
        self.preds.append(preds)
        self.target.append(target)

    def compute(self) -> Tensor:
        """Compute rand score over all labels accumulated in the state."""
        return rand_score(dim_zero_cat(self.preds), dim_zero_cat(self.target))

    def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: An matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.clustering import RandScore
            >>> metric = RandScore()
            >>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
            >>> fig_, ax_ = metric.plot(metric.compute())

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.clustering import RandScore
            >>> metric = RandScore()
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)
def _rand_score_update(preds: Tensor, target: Tensor) -> Tensor:
    """Validate the labels and build the contingency matrix needed for the rand score.

    Args:
        preds: predicted cluster labels
        target: ground truth cluster labels

    Returns:
        contingency: contingency matrix

    """
    # Validation happens first so invalid labels never reach the matrix construction.
    check_cluster_labels(preds, target)
    contingency = calculate_contingency_matrix(preds, target)
    return contingency


def _rand_score_compute(contingency: Tensor) -> Tensor:
    """Turn a contingency matrix into the rand score.

    Args:
        contingency: contingency matrix

    Returns:
        rand_score: rand score

    """
    pair_counts = calcualte_pair_cluster_confusion_matrix(contingency=contingency)

    # The diagonal of the pair matrix holds the pairs on which both clusterings agree.
    agreeing_pairs = pair_counts.diagonal().sum()
    all_pairs = pair_counts.sum()

    # Limit cases: either there are no pairs at all, or every pair agrees (data not split,
    # or each sample in its own cluster). Both are treated as a perfect match.
    if agreeing_pairs == all_pairs or all_pairs == 0:
        return torch.ones_like(agreeing_pairs, dtype=torch.float32)

    return agreeing_pairs / all_pairs


def rand_score(preds: Tensor, target: Tensor) -> Tensor:
    """Compute the Rand score between two clusterings.

    Args:
        preds: predicted cluster labels
        target: ground truth cluster labels

    Returns:
        scalar tensor with the rand score

    Example:
        >>> from torchmetrics.functional.clustering import rand_score
        >>> import torch
        >>> rand_score(torch.tensor([0, 0, 1, 1]), torch.tensor([1, 1, 0, 0]))
        tensor(1.)
        >>> rand_score(torch.tensor([0, 0, 1, 2]), torch.tensor([0, 0, 1, 1]))
        tensor(0.8333)

    """
    return _rand_score_compute(_rand_score_update(preds, target))
def calcualte_pair_cluster_confusion_matrix(
    preds: Optional[Tensor] = None,
    target: Optional[Tensor] = None,
    contingency: Optional[Tensor] = None,
) -> Tensor:
    """Calculates the pair cluster confusion matrix.

    The input is either the predicted and target cluster labels, or an already computed contingency matrix.
    The result is a 2x2 matrix describing the similarity of two clusterings: every ordered pair of distinct
    samples is counted according to whether the predicted and the target clustering each assign the pair to the
    same cluster or to different clusters. The diagonal holds the pairs on which both clusterings agree
    (``[0, 0]``: separated in both, ``[1, 1]``: grouped together in both); the off-diagonal entries hold the pairs
    that only one of the two clusterings groups together.

    Note that the matrix is not symmetric.

    Inspired by:
    https://scikit-learn.org/stable/modules/generated/sklearn.metrics.cluster.pair_confusion_matrix.html

    Args:
        preds: predicted cluster labels
        target: ground truth cluster labels
        contingency: contingency matrix

    Returns:
        A 2x2 tensor containing the pair cluster confusion matrix.

    Raises:
        ValueError:
            If neither `preds` and `target` nor `contingency` are provided.
        ValueError:
            If both `preds` and `target` and `contingency` are provided.

    Example:
        >>> import torch
        >>> from torchmetrics.functional.clustering.utils import calcualte_pair_cluster_confusion_matrix
        >>> preds = torch.tensor([0, 0, 1, 1])
        >>> target = torch.tensor([1, 1, 0, 0])
        >>> calcualte_pair_cluster_confusion_matrix(preds, target)
        tensor([[8, 0],
                [0, 4]])
        >>> preds = torch.tensor([0, 0, 1, 2])
        >>> target = torch.tensor([0, 0, 1, 1])
        >>> calcualte_pair_cluster_confusion_matrix(preds, target)
        tensor([[8, 2],
                [0, 2]])

    """
    given = [arg is not None for arg in (preds, target, contingency)]
    if not any(given):
        raise ValueError("Must provide either `preds` and `target` or `contingency`.")
    if all(given):
        raise ValueError("Must provide either `preds` and `target` or `contingency`, not both.")

    have_preds, have_target, _ = given
    if have_preds and have_target:
        contingency = calculate_contingency_matrix(preds, target)

    # Reached with `contingency` still unset when only one of the two label tensors was passed.
    if contingency is None:
        raise ValueError("Must provide `contingency` if `preds` and `target` are not provided.")

    total = contingency.sum()
    row_totals = contingency.sum(dim=1)
    col_totals = contingency.sum(dim=0)
    squared_total = contingency.pow(2).sum()

    # Filled cell by cell so every entry is cast to the dtype of the contingency matrix.
    result = torch.zeros(2, 2, dtype=contingency.dtype, device=contingency.device)
    result[1, 1] = squared_total - total
    result[1, 0] = (contingency * col_totals).sum() - squared_total
    result[0, 1] = (contingency.T * row_totals).sum() - squared_total
    # Whatever is left of the n^2 ordered pairs after removing the other cells and the self-pairs.
    result[0, 0] = total**2 - result[0, 1] - result[1, 0] - squared_total
    return result
NUM_PROCESSES, THRESHOLD, setup_ddp +from unittests.conftest import ( + BATCH_SIZE, + EXTRA_DIM, + NUM_BATCHES, + NUM_CLASSES, + NUM_PROCESSES, + THRESHOLD, + setup_ddp, + skip_on_running_out_of_memory, +) # adding compatibility for numpy >= 1.24 for tp_name, tp_ins in [("object", object), ("bool", bool), ("int", int), ("float", float)]: @@ -25,4 +34,5 @@ "NUM_PROCESSES", "THRESHOLD", "setup_ddp", + "skip_on_running_out_of_memory", ] diff --git a/tests/unittests/clustering/test_rand_score.py b/tests/unittests/clustering/test_rand_score.py new file mode 100644 index 00000000000..d00fd421d34 --- /dev/null +++ b/tests/unittests/clustering/test_rand_score.py @@ -0,0 +1,94 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
@pytest.mark.parametrize(
    ("preds", "target"),
    [
        (_single_target_inputs1.preds, _single_target_inputs1.target),
        (_single_target_inputs2.preds, _single_target_inputs2.target),
    ],
)
class TestRandScore(MetricTester):
    """Test class for `RandScore` metric."""

    atol = 1e-5

    @pytest.mark.parametrize("ddp", [True, False])
    def test_rand_score(self, preds, target, ddp):
        """Test class implementation of metric."""
        self.run_class_metric_test(
            ddp=ddp,
            preds=preds,
            target=target,
            metric_class=RandScore,
            reference_metric=sklearn_rand_score,
        )

    def test_rand_score_functional(self, preds, target):
        """Test functional implementation of metric."""
        self.run_functional_metric_test(
            preds=preds,
            target=target,
            metric_functional=rand_score,
            reference_metric=sklearn_rand_score,
        )


def test_rand_score_functional_raises_invalid_task():
    """Check that metric rejects continuous-valued inputs."""
    preds, target = _float_inputs
    # Match the actual message text; a pattern like "Expected *" would accept any error mentioning "Expected".
    with pytest.raises(ValueError, match="Expected real, discrete values"):
        rand_score(preds, target)


@pytest.mark.parametrize(
    ("preds", "target"),
    [(_single_target_inputs1.preds, _single_target_inputs1.target)],
)
def test_rand_score_functional_is_symmetric(preds, target):
    """Check that the metric functional is symmetric."""
    for p, t in zip(preds, target):
        assert torch.allclose(rand_score(p, t), rand_score(t, p))


@pytest.mark.parametrize(
    ("preds", "target"),
    [(_sklearn_inputs.preds, _sklearn_inputs.target), (_single_dim_inputs.preds, _single_dim_inputs.target)],
)
class TestPairClusterConfusionMatrix:
    """Test that the implementation matches sklearn's."""

    atol = 1e-8

    def test_pair_cluster_confusion_matrix(self, preds, target):
        """Check that pair cluster confusion matrix is calculated correctly."""
        tm_res = calcualte_pair_cluster_confusion_matrix(preds, target)
        sklearn_res = sklearn_pair_confusion_matrix(preds, target)
        assert np.allclose(tm_res, sklearn_res, atol=self.atol)
def skip_on_running_out_of_memory(
    reason: str = "Skipping test as it ran out of memory.",
) -> Callable[[Callable], Callable]:
    """Handle tests that sometimes run out of memory, by simply skipping them.

    Args:
        reason: message passed to ``pytest.skip`` when the wrapped test runs out of memory.

    Returns:
        A decorator that wraps a test function. The wrapped test is skipped when it raises a ``RuntimeError``
        whose message contains the CPU allocator out-of-memory text; any other ``RuntimeError`` is re-raised.

    """

    def test_decorator(function: Callable) -> Callable:
        @wraps(function)
        def run_test(*args: Any, **kwargs: Any) -> Any:
            try:
                return function(*args, **kwargs)
            except RuntimeError as ex:
                if "DefaultCPUAllocator: not enough memory:" not in str(ex):
                    # Not an out-of-memory failure: let the original error propagate unchanged.
                    raise
                pytest.skip(reason)

        return run_test

    return test_decorator
`upper_discard` must be a float between 0 and 1 or `None`, but got 2"), ], ) +@skip_on_running_out_of_memory() def test_raises_error_on_wrong_arguments(argument, match): """Test that appropriate errors are raised on wrong arguments.""" with pytest.raises(ValueError, match=match): diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index b47b06da7f8..f5c7d8d4562 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -91,6 +91,7 @@ MultilabelROC, MultilabelSpecificity, ) +from torchmetrics.clustering import MutualInfoScore, RandScore from torchmetrics.detection import PanopticQuality from torchmetrics.detection.mean_ap import MeanAveragePrecision from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio @@ -614,6 +615,8 @@ id="squad", ), pytest.param(TranslationEditRate, _text_input_3, _text_input_4, id="translation edit rate"), + pytest.param(MutualInfoScore, _nominal_input, _nominal_input, id="mutual info score"), + pytest.param(RandScore, _nominal_input, _nominal_input, id="rand score"), ], ) @pytest.mark.parametrize("num_vals", [1, 3])