diff --git a/CHANGELOG.md b/CHANGELOG.md index 97aae384402..777937c3c63 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `RetrievalMRR` metric for Information Retrieval ([#119](https://github.com/PyTorchLightning/metrics/pull/119)) +- Added `RetrievalPrecision` metric for Information Retrieval ([#119](https://github.com/PyTorchLightning/metrics/pull/119)) + + - Added `average='micro'` as an option in AUROC for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110)) @@ -38,6 +41,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ) +- Added `BootStrapper` to easely calculate confidence intervals for metrics ([#101](https://github.com/PyTorchLightning/metrics/pull/101)) + + ### Changed - Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68)) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 32ff1e597d0..5f450eca1df 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -248,7 +248,14 @@ retrieval_average_precision [func] retrieval_reciprocal_rank [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: torchmetrics.functional.retrieval_reciprocal_rank :noindex: + + +retrieval_precision [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: torchmetrics.functional.retrieval_precision + :noindex: diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 3701ada3844..f6f6d32af86 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -334,3 +334,21 @@ RetrievalMRR .. autoclass:: torchmetrics.RetrievalMRR :noindex: + + +RetrievalPrecision +~~~~~~~~~~~~~~~~~~ + +.. autoclass:: torchmetrics.RetrievalPrecision + :noindex: + + +******** +Wrappers +******** + +Modular wrapper metrics are not metrics in themself, but instead take a metric and alter the internal logic +of the base metric. + +.. autoclass:: torchmetrics.BootStrapper + :noindex: diff --git a/tests/functional/test_retrieval.py b/tests/functional/test_retrieval.py index 4a1bd400d20..3d5e213a5d6 100644 --- a/tests/functional/test_retrieval.py +++ b/tests/functional/test_retrieval.py @@ -7,7 +7,9 @@ from tests.helpers import seed_all from tests.retrieval.test_mrr import _reciprocal_rank as reciprocal_rank +from tests.retrieval.test_precision import _precision_at_k as precision_at_k from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision +from torchmetrics.functional.retrieval.precision import retrieval_precision from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank seed_all(1337) @@ -42,9 +44,39 @@ def test_metrics_output_values(sklearn_metric, torch_metric, size): assert torch.allclose(sk.float(), tm.float()) +@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [ + [precision_at_k, retrieval_precision], +]) +@pytest.mark.parametrize("size", [1, 4, 10]) +@pytest.mark.parametrize("k", [None, 1, 4, 10]) +def test_metrics_output_values_with_k(sklearn_metric, torch_metric, size, k): + """ Compare PL metrics to sklearn version. 
""" + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + # test results are computed correctly wrt std implementation + for i in range(6): + preds = np.random.randn(size) + target = np.random.randn(size) > 0 + + # sometimes test with integer targets + if (i % 2) == 0: + target = target.astype(np.int) + + sk = torch.tensor(sklearn_metric(target, preds, k), device=device) + tm = torch_metric(torch.tensor(preds, device=device), torch.tensor(target, device=device), k) + + # `torch_metric`s return 0 when no label is True + # while `sklearn` metrics returns NaN + if math.isnan(sk): + assert tm == 0 + else: + assert torch.allclose(sk.float(), tm.float()) + + @pytest.mark.parametrize(['torch_metric'], [ [retrieval_average_precision], [retrieval_reciprocal_rank], + [retrieval_precision] ]) def test_input_dtypes(torch_metric) -> None: """ Check wrong input dtypes are managed correctly. """ @@ -75,6 +107,7 @@ def test_input_dtypes(torch_metric) -> None: @pytest.mark.parametrize(['torch_metric'], [ [retrieval_average_precision], [retrieval_reciprocal_rank], + [retrieval_precision] ]) def test_input_shapes(torch_metric) -> None: """ Check wrong input shapes are managed correctly. """ @@ -93,3 +126,19 @@ def test_input_shapes(torch_metric) -> None: with pytest.raises(ValueError, match="`preds` and `target` must be of the same shape"): torch_metric(preds, target) + + +# test metrics using top K parameter +@pytest.mark.parametrize(['torch_metric'], [ + [retrieval_precision] +]) +@pytest.mark.parametrize('k', [-1, 1.0]) +def test_input_params(torch_metric, k) -> None: + """ Check wrong input shapes are managed correctly. """ + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + # test with random tensors + preds = torch.tensor([0] * 4, device=device, dtype=torch.float) + target = torch.tensor([0] * 4, device=device, dtype=torch.int64) + with pytest.raises(ValueError, match="`k` has to be a positive integer or None"): + torch_metric(preds, target, k=k) diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index 76994a1b4c9..8a3c7bf899e 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -6,12 +6,13 @@ from torch import Tensor from tests.helpers import seed_all +from torchmetrics import Metric seed_all(1337) def _compute_sklearn_metric( - metric: Callable, target: List[np.ndarray], preds: List[np.ndarray], behaviour: str + metric: Callable, target: List[np.ndarray], preds: List[np.ndarray], behaviour: str, **kwargs ) -> Tensor: """ Compute metric with multiple iterations over every query predictions set. """ sk_results = [] @@ -25,7 +26,7 @@ def _compute_sklearn_metric( else: sk_results.append(0.0) else: - res = metric(b, a) + res = metric(b, a, **kwargs) sk_results.append(res) if len(sk_results) > 0: @@ -34,10 +35,15 @@ def _compute_sklearn_metric( def _test_retrieval_against_sklearn( - sklearn_metric, torch_metric, size, n_documents, query_without_relevant_docs_options + sklearn_metric: Callable, + torch_metric: Metric, + size: int, + n_documents: int, + query_without_relevant_docs_options: str, + **kwargs ) -> None: """ Compare PL metrics to standard version. 
""" - metric = torch_metric(query_without_relevant_docs=query_without_relevant_docs_options) + metric = torch_metric(query_without_relevant_docs=query_without_relevant_docs_options, **kwargs) shape = (size, ) indexes = [] @@ -49,7 +55,7 @@ def _test_retrieval_against_sklearn( preds.append(np.random.randn(*shape)) target.append(np.random.randn(*shape) > 0) - sk_results = _compute_sklearn_metric(sklearn_metric, target, preds, query_without_relevant_docs_options) + sk_results = _compute_sklearn_metric(sklearn_metric, target, preds, query_without_relevant_docs_options, **kwargs) sk_results = torch.tensor(sk_results) indexes_tensor = torch.cat([torch.tensor(i) for i in indexes]).long() @@ -120,3 +126,9 @@ def _test_input_shapes(torchmetric) -> None: with pytest.raises(ValueError, match="`indexes`, `preds` and `target` must be of the same shape"): metric(indexes, preds, target) + + +def _test_input_args(torchmetric: Metric, message: str, **kwargs) -> None: + """Check invalid args are managed correctly. """ + with pytest.raises(ValueError, match=message): + torchmetric(**kwargs) diff --git a/tests/retrieval/test_mrr.py b/tests/retrieval/test_mrr.py index 25cd3e88e26..1b4e2cfbc21 100644 --- a/tests/retrieval/test_mrr.py +++ b/tests/retrieval/test_mrr.py @@ -1,5 +1,6 @@ import numpy as np import pytest +from sklearn.metrics import label_ranking_average_precision_score from tests.retrieval.helpers import _test_dtypes, _test_input_shapes, _test_retrieval_against_sklearn from torchmetrics.retrieval.mean_reciprocal_rank import RetrievalMRR @@ -7,19 +8,24 @@ def _reciprocal_rank(target: np.array, preds: np.array): """ - Implementation of reciprocal rank because couldn't find a good implementation. - `sklearn.metrics.label_ranking_average_precision_score` is similar but works in a different way - then the number of positive labels is greater than 1. + Adaptation of `sklearn.metrics.label_ranking_average_precision_score`. + Since the original sklearn metric works as RR only when the number of positive + targets is exactly 1, here we remove every positive target that is not the most + important. Remember that in RR only the positive target with the highest score is considered. 
""" assert target.shape == preds.shape assert len(target.shape) == 1 # works only with single dimension inputs + # going to remove T targets that are not ranked as highest + indexes = preds[target.astype(np.bool)] + if len(indexes) > 0: + target[preds != indexes.max(-1, keepdims=True)[0]] = 0 # ensure that only 1 positive label is present + if target.sum() > 0: - target = target[np.argsort(preds, axis=-1)][::-1] - rank = np.nonzero(target)[0][0] + 1 - return 1.0 / rank + # sklearn `label_ranking_average_precision_score` requires at most 2 dims + return label_ranking_average_precision_score(np.expand_dims(target, axis=0), np.expand_dims(preds, axis=0)) else: - return np.NaN + return 0.0 @pytest.mark.parametrize('size', [1, 4, 10]) diff --git a/tests/retrieval/test_precision.py b/tests/retrieval/test_precision.py new file mode 100644 index 00000000000..ed16860c926 --- /dev/null +++ b/tests/retrieval/test_precision.py @@ -0,0 +1,56 @@ +import numpy as np +import pytest + +from tests.retrieval.helpers import _test_dtypes, _test_input_args, _test_input_shapes, _test_retrieval_against_sklearn +from torchmetrics.retrieval.retrieval_precision import RetrievalPrecision + + +def _precision_at_k(target: np.array, preds: np.array, k: int = None): + """ + Didn't find a reliable implementation of Precision in Information Retrieval, so, + reimplementing here. A good explanation can be found `` + """ + assert target.shape == preds.shape + assert len(target.shape) == 1 # works only with single dimension inputs + + if k is None: + k = len(preds) + + if target.sum() > 0: + order_indexes = np.argsort(preds, axis=0)[::-1] + relevant = np.sum(target[order_indexes][:k]) + return relevant * 1.0 / k + else: + return np.NaN + + +@pytest.mark.parametrize('size', [1, 4, 10]) +@pytest.mark.parametrize('n_documents', [1, 5]) +@pytest.mark.parametrize('query_without_relevant_docs_options', ['skip', 'pos', 'neg']) +@pytest.mark.parametrize('k', [None, 1, 4, 10]) +def test_results(size, n_documents, query_without_relevant_docs_options, k): + """ Test metrics are computed correctly. """ + _test_retrieval_against_sklearn( + _precision_at_k, + RetrievalPrecision, + size, + n_documents, + query_without_relevant_docs_options, + k=k + ) + + +def test_dtypes(): + """ Check dypes are managed correctly. """ + _test_dtypes(RetrievalPrecision) + + +def test_input_shapes() -> None: + """Check inputs shapes are managed correctly. """ + _test_input_shapes(RetrievalPrecision) + + +@pytest.mark.parametrize('k', [-1, 1.0]) +def test_input_params(k) -> None: + """Check invalid args are managed correctly. """ + _test_input_args(RetrievalPrecision, "`k` has to be a positive integer or None", k=k) diff --git a/tests/wrappers/__init__.py b/tests/wrappers/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/wrappers/test_bootstrapping.py b/tests/wrappers/test_bootstrapping.py new file mode 100644 index 00000000000..afea77d4267 --- /dev/null +++ b/tests/wrappers/test_bootstrapping.py @@ -0,0 +1,108 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import operator + +import numpy as np +import pytest +import torch +from sklearn.metrics import precision_score, recall_score +from torch import Tensor + +from torchmetrics.classification import Precision, Recall +from torchmetrics.utilities import apply_to_collection +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_7 +from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler + +_preds = torch.randint(10, (10, 32)) +_target = torch.randint(10, (10, 32)) + + +class TestBootStrapper(BootStrapper): + """ For testing purpose, we subclass the bootstrapper class so we can get the exact permutation + the class is creating + """ + def update(self, *args) -> None: + self.out = [] + for idx in range(self.num_bootstraps): + size = len(args[0]) + sample_idx = _bootstrap_sampler(size, sampling_strategy=self.sampling_strategy) + new_args = apply_to_collection(args, Tensor, torch.index_select, dim=0, index=sample_idx) + self.metrics[idx].update(*new_args) + self.out.append(new_args) + + +def _sample_checker(old_samples, new_samples, op: operator, threshold: int): + found_one = False + for os in old_samples: + cond = op(os, new_samples) + if cond.sum() > threshold: + found_one = True + break + return found_one + + +@pytest.mark.parametrize("sampling_strategy", ['poisson', 'multinomial']) +def test_bootstrap_sampler(sampling_strategy): + """ make sure that the bootstrap sampler works as intended """ + old_samples = torch.randn(10, 2) + + # make sure that the new samples are only made up of old samples + idx = _bootstrap_sampler(10, sampling_strategy=sampling_strategy) + new_samples = old_samples[idx] + for ns in new_samples: + assert ns in old_samples + + found_one = _sample_checker(old_samples, new_samples, operator.eq, 2) + assert found_one, "resampling did not work because no samples were sampled twice" + + found_zero = _sample_checker(old_samples, new_samples, operator.ne, 0) + assert found_zero, "resampling did not work because all samples were atleast sampled once" + + +@pytest.mark.parametrize("sampling_strategy", ['poisson', 'multinomial']) +@pytest.mark.parametrize( + "metric, sk_metric", [[Precision(average='micro'), precision_score], [Recall(average='micro'), recall_score]] +) +def test_bootstrap(sampling_strategy, metric, sk_metric): + """ Test that the different bootstraps gets updated as we expected and that the compute method works """ + _kwargs = {'base_metric': metric, 'mean': True, 'std': True, 'raw': True, 'sampling_strategy': sampling_strategy} + if _TORCH_GREATER_EQUAL_1_7: + _kwargs.update(dict(quantile=torch.tensor([0.05, 0.95]))) + + bootstrapper = TestBootStrapper(**_kwargs) + + collected_preds = [[] for _ in range(10)] + collected_target = [[] for _ in range(10)] + for p, t in zip(_preds, _target): + bootstrapper.update(p, t) + + for i, o in enumerate(bootstrapper.out): + + collected_preds[i].append(o[0]) + collected_target[i].append(o[1]) + + collected_preds = [torch.cat(cp) for cp in collected_preds] + collected_target = [torch.cat(ct) for ct in collected_target] + + sk_scores = [sk_metric(ct, cp, average='micro') for ct, cp in zip(collected_target, collected_preds)] + + output = bootstrapper.compute() + # quantile only avaible for pytorch v1.7 and forward + if _TORCH_GREATER_EQUAL_1_7: + assert np.allclose(output['quantile'][0], np.quantile(sk_scores, 0.05)) + assert np.allclose(output['quantile'][1], np.quantile(sk_scores, 0.95)) + 
+ assert np.allclose(output['mean'], np.mean(sk_scores)) + assert np.allclose(output['std'], np.std(sk_scores, ddof=1)) + assert np.allclose(output['raw'], sk_scores) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 05408441486..7e21b7886df 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -49,4 +49,5 @@ MeanSquaredLogError, R2Score, ) -from torchmetrics.retrieval import RetrievalMAP, RetrievalMRR # noqa: F401 E402 +from torchmetrics.retrieval import RetrievalMAP, RetrievalMRR, RetrievalPrecision # noqa: F401 E402 +from torchmetrics.wrappers import BootStrapper # noqa: F401 E402 diff --git a/torchmetrics/classification/auroc.py b/torchmetrics/classification/auroc.py index 79f19415857..52716793e4e 100644 --- a/torchmetrics/classification/auroc.py +++ b/torchmetrics/classification/auroc.py @@ -78,8 +78,7 @@ class AUROC(Metric): ValueError: If the mode of data (binary, multi-label, multi-class) changes between batches. - Example: - >>> # binary case + Example (binary case): >>> from torchmetrics import AUROC >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) >>> target = torch.tensor([0, 0, 1, 1, 1]) @@ -87,7 +86,7 @@ class AUROC(Metric): >>> auroc(preds, target) tensor(0.5000) - >>> # multiclass case + Example (multiclass case): >>> preds = torch.tensor([[0.90, 0.05, 0.05], ... [0.05, 0.90, 0.05], ... [0.05, 0.05, 0.90], diff --git a/torchmetrics/classification/average_precision.py b/torchmetrics/classification/average_precision.py index 0ecfeb3c864..968191517c3 100644 --- a/torchmetrics/classification/average_precision.py +++ b/torchmetrics/classification/average_precision.py @@ -52,8 +52,7 @@ class AveragePrecision(Metric): process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) - Example: - >>> # binary case + Example (binary case): >>> from torchmetrics import AveragePrecision >>> pred = torch.tensor([0, 1, 2, 3]) >>> target = torch.tensor([0, 1, 1, 1]) @@ -61,7 +60,7 @@ class AveragePrecision(Metric): >>> average_precision(pred, target) tensor(1.) - >>> # multiclass case + Example (multiclass case): >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], diff --git a/torchmetrics/classification/hinge.py b/torchmetrics/classification/hinge.py index d1492fec0a8..bfb2a5ea5b7 100644 --- a/torchmetrics/classification/hinge.py +++ b/torchmetrics/classification/hinge.py @@ -60,8 +60,7 @@ class Hinge(Metric): If ``multiclass_mode`` is not: None, ``MulticlassMode.CRAMMER_SINGER``, ``"crammer-singer"``, ``MulticlassMode.ONE_VS_ALL`` or ``"one-vs-all"``. 
- Example: - # binary example + Example (binary case): >>> import torch >>> from torchmetrics import Hinge >>> target = torch.tensor([0, 1, 1]) @@ -70,16 +69,14 @@ class Hinge(Metric): >>> hinge(preds, target) tensor(0.3000) - - # multiclass example, default mode + Example (default / multiclass case): >>> target = torch.tensor([0, 1, 2]) >>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]]) >>> hinge = Hinge() >>> hinge(preds, target) tensor(2.9000) - - # multiclass example, one vs all mode + Example (multiclass example, one vs all mode): >>> target = torch.tensor([0, 1, 2]) >>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]]) >>> hinge = Hinge(multiclass_mode="one-vs-all") diff --git a/torchmetrics/classification/matthews_corrcoef.py b/torchmetrics/classification/matthews_corrcoef.py index a240d97ec60..91978e92706 100644 --- a/torchmetrics/classification/matthews_corrcoef.py +++ b/torchmetrics/classification/matthews_corrcoef.py @@ -66,7 +66,6 @@ class MatthewsCorrcoef(Metric): will be used to perform the allgather Example: - >>> from torchmetrics import MatthewsCorrcoef >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0, 1, 0, 0]) diff --git a/torchmetrics/classification/precision_recall_curve.py b/torchmetrics/classification/precision_recall_curve.py index de781cfff15..18f11d0f956 100644 --- a/torchmetrics/classification/precision_recall_curve.py +++ b/torchmetrics/classification/precision_recall_curve.py @@ -52,8 +52,7 @@ class PrecisionRecallCurve(Metric): process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) - Example: - >>> # binary case + Example (binary case): >>> from torchmetrics import PrecisionRecallCurve >>> pred = torch.tensor([0, 1, 2, 3]) >>> target = torch.tensor([0, 1, 1, 0]) @@ -66,7 +65,7 @@ class PrecisionRecallCurve(Metric): >>> thresholds tensor([1, 2, 3]) - >>> # multiclass case + Example (multiclass case): >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], diff --git a/torchmetrics/classification/roc.py b/torchmetrics/classification/roc.py index bd93695153e..5f25537e45d 100644 --- a/torchmetrics/classification/roc.py +++ b/torchmetrics/classification/roc.py @@ -53,7 +53,6 @@ class ROC(Metric): will be used to perform the allgather Example (binary case): - >>> from torchmetrics import ROC >>> pred = torch.tensor([0, 1, 2, 3]) >>> target = torch.tensor([0, 1, 1, 1]) @@ -67,8 +66,6 @@ class ROC(Metric): tensor([4, 3, 2, 1, 0]) Example (multiclass case): - - >>> from torchmetrics import ROC >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05], @@ -87,8 +84,6 @@ class ROC(Metric): tensor([1.7500, 0.7500, 0.0500])] Example (multilabel case): - - >>> from torchmetrics import ROC >>> pred = torch.tensor([[0.8191, 0.3680, 0.1138], ... [0.3584, 0.7576, 0.1183], ... [0.2286, 0.3468, 0.1338], diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py index cda4dacaa4d..3f0e0933c69 100644 --- a/torchmetrics/collections.py +++ b/torchmetrics/collections.py @@ -46,9 +46,9 @@ class MetricCollection(nn.ModuleDict): ValueError: If ``metrics`` is not a ``list``, ``tuple`` or a ``dict``. 
- Example: - >>> # input as list + Example (input as list): >>> import torch + >>> from pprint import pprint >>> from torchmetrics import MetricCollection, Accuracy, Precision, Recall >>> target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) >>> preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) @@ -58,14 +58,14 @@ class MetricCollection(nn.ModuleDict): >>> metrics(preds, target) {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)} - >>> # input as dict + Example (input as dict): >>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'), ... 'macro_recall': Recall(num_classes=3, average='macro')}) >>> same_metric = metrics.clone() - >>> metrics(preds, target) - {'micro_recall': tensor(0.1250), 'macro_recall': tensor(0.1111)} - >>> same_metric(preds, target) - {'micro_recall': tensor(0.1250), 'macro_recall': tensor(0.1111)} + >>> pprint(metrics(preds, target)) + {'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)} + >>> pprint(same_metric(preds, target)) + {'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)} >>> metrics.persistent() """ diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 9c133dfff25..18d90b6d9ce 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -38,5 +38,6 @@ from torchmetrics.functional.regression.r2score import r2score # noqa: F401 from torchmetrics.functional.regression.ssim import ssim # noqa: F401 from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision # noqa: F401 +from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401 from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401 from torchmetrics.functional.self_supervised import embedding_similarity # noqa: F401 diff --git a/torchmetrics/functional/classification/auroc.py b/torchmetrics/functional/classification/auroc.py index d8c8ddd7be0..25d60d27e0b 100644 --- a/torchmetrics/functional/classification/auroc.py +++ b/torchmetrics/functional/classification/auroc.py @@ -177,15 +177,14 @@ def auroc( ValueError: If ``average`` is none of ``None``, ``"macro"`` or ``"weighted"``. - Example: - >>> # binary case + Example (binary case): >>> from torchmetrics.functional import auroc >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) >>> target = torch.tensor([0, 0, 1, 1, 1]) >>> auroc(preds, target, pos_label=1) tensor(0.5000) - >>> # multiclass case + Example (multiclass case): >>> preds = torch.tensor([[0.90, 0.05, 0.05], ... [0.05, 0.90, 0.05], ... [0.05, 0.05, 0.90], diff --git a/torchmetrics/functional/classification/average_precision.py b/torchmetrics/functional/classification/average_precision.py index 543bbfb943f..6f3b9328d16 100644 --- a/torchmetrics/functional/classification/average_precision.py +++ b/torchmetrics/functional/classification/average_precision.py @@ -77,15 +77,14 @@ def average_precision( tensor with average precision. If multiclass will return list of such tensors, one for each class - Example: - >>> # binary case + Example (binary case): >>> from torchmetrics.functional import average_precision >>> pred = torch.tensor([0, 1, 2, 3]) >>> target = torch.tensor([0, 1, 1, 1]) >>> average_precision(pred, target, pos_label=1) tensor(1.) - >>> # multiclass case + Example (multiclass case): >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... 
[0.05, 0.05, 0.75, 0.05, 0.05], diff --git a/torchmetrics/functional/classification/hinge.py b/torchmetrics/functional/classification/hinge.py index e7d4ead5c00..35c09cedd42 100644 --- a/torchmetrics/functional/classification/hinge.py +++ b/torchmetrics/functional/classification/hinge.py @@ -154,8 +154,7 @@ def hinge( If ``multiclass_mode`` is not: None, ``MulticlassMode.CRAMMER_SINGER``, ``"crammer-singer"``, ``MulticlassMode.ONE_VS_ALL`` or ``"one-vs-all"``. - Example: - # binary example + Example (binary case): >>> import torch >>> from torchmetrics.functional import hinge >>> target = torch.tensor([0, 1, 1]) @@ -163,15 +162,13 @@ def hinge( >>> hinge(preds, target) tensor(0.3000) - - # multiclass example, default mode + Example (default / multiclass case): >>> target = torch.tensor([0, 1, 2]) >>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]]) >>> hinge(preds, target) tensor(2.9000) - - # multiclass example, one vs all mode + Example (multiclass example, one vs all mode): >>> target = torch.tensor([0, 1, 2]) >>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]]) >>> hinge(preds, target, multiclass_mode="one-vs-all") diff --git a/torchmetrics/functional/classification/precision_recall_curve.py b/torchmetrics/functional/classification/precision_recall_curve.py index 7d6881b23f3..68fd443c789 100644 --- a/torchmetrics/functional/classification/precision_recall_curve.py +++ b/torchmetrics/functional/classification/precision_recall_curve.py @@ -203,8 +203,7 @@ def precision_recall_curve( If the number of classes deduced from ``preds`` is not the same as the ``num_classes`` provided. - Example: - >>> # binary case + Example (binary case): >>> from torchmetrics.functional import precision_recall_curve >>> pred = torch.tensor([0, 1, 2, 3]) >>> target = torch.tensor([0, 1, 1, 0]) @@ -216,7 +215,7 @@ def precision_recall_curve( >>> thresholds tensor([1, 2, 3]) - >>> # multiclass case + Example (multiclass case): >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], diff --git a/torchmetrics/functional/classification/roc.py b/torchmetrics/functional/classification/roc.py index b7fee4ba620..9cb8c2f5762 100644 --- a/torchmetrics/functional/classification/roc.py +++ b/torchmetrics/functional/classification/roc.py @@ -121,7 +121,6 @@ def roc( If multiclass or multilabel, this is a list of such tensors, one for each class/label. Example (binary case): - >>> from torchmetrics.functional import roc >>> pred = torch.tensor([0, 1, 2, 3]) >>> target = torch.tensor([0, 1, 1, 1]) @@ -134,7 +133,6 @@ def roc( tensor([4, 3, 2, 1, 0]) Example (multiclass case): - >>> from torchmetrics.functional import roc >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05], @@ -153,7 +151,6 @@ def roc( tensor([1.7500, 0.7500, 0.0500])] Example (multilabel case): - >>> from torchmetrics.functional import roc >>> pred = torch.tensor([[0.8191, 0.3680, 0.1138], ... [0.3584, 0.7576, 0.1183], diff --git a/torchmetrics/functional/retrieval/__init__.py b/torchmetrics/functional/retrieval/__init__.py index 07f7c57a278..ff92f8db3c6 100644 --- a/torchmetrics/functional/retrieval/__init__.py +++ b/torchmetrics/functional/retrieval/__init__.py @@ -13,4 +13,5 @@ # limitations under the License. 
from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision # noqa: F401 +from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401 from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401 diff --git a/torchmetrics/functional/retrieval/average_precision.py b/torchmetrics/functional/retrieval/average_precision.py index 29e6ff6914c..4e4672b91c7 100644 --- a/torchmetrics/functional/retrieval/average_precision.py +++ b/torchmetrics/functional/retrieval/average_precision.py @@ -18,7 +18,7 @@ def retrieval_average_precision(preds: Tensor, target: Tensor) -> Tensor: - r""" + """ Computes average precision (for information retrieval), as explained `here `__. @@ -31,11 +31,12 @@ def retrieval_average_precision(preds: Tensor, target: Tensor) -> Tensor: target: ground truth about each document being relevant or not. Return: - a single-value tensor with the average precision (AP) of the predictions ``preds`` wrt the labels ``target``. + a single-value tensor with the average precision (AP) of the predictions ``preds`` w.r.t. the labels ``target``. Example: - >>> preds = torch.tensor([0.2, 0.3, 0.5]) - >>> target = torch.tensor([True, False, True]) + >>> from torchmetrics.functional import retrieval_average_precision + >>> preds = tensor([0.2, 0.3, 0.5]) + >>> target = tensor([True, False, True]) >>> retrieval_average_precision(preds, target) tensor(0.8333) """ diff --git a/torchmetrics/functional/retrieval/precision.py b/torchmetrics/functional/retrieval/precision.py new file mode 100644 index 00000000000..d896697d358 --- /dev/null +++ b/torchmetrics/functional/retrieval/precision.py @@ -0,0 +1,56 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import Tensor, tensor + +from torchmetrics.utilities.checks import _check_retrieval_functional_inputs + + +def retrieval_precision(preds: Tensor, target: Tensor, k: int = None) -> Tensor: + """ + Computes the precision metric (for information retrieval), + as explained `here `__. + Precision is the fraction of relevant documents among all the retrieved documents. + + ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``, + ``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`, + otherwise an error is raised. If you want to measure Precision@K, ``k`` must be a positive integer. + + Args: + preds: estimated probabilities of each document to be relevant. + target: ground truth about each document being relevant or not. + k: consider only the top k elements (default: None) + + Returns: + a single-value tensor with the precision (at ``k``) of the predictions ``preds`` w.r.t. the labels ``target``. 
+ + Example: + >>> preds = tensor([0.2, 0.3, 0.5]) + >>> target = tensor([True, False, True]) + >>> retrieval_precision(preds, target, k=2) + tensor(0.5000) + """ + preds, target = _check_retrieval_functional_inputs(preds, target) + + if k is None: + k = preds.shape[-1] + + if not (isinstance(k, int) and k > 0): + raise ValueError("`k` has to be a positive integer or None") + + if target.sum() == 0: + return tensor(0.0, device=preds.device) + + relevant = target[torch.argsort(preds, dim=-1, descending=True)][:k].sum().float() + return relevant / k diff --git a/torchmetrics/functional/retrieval/reciprocal_rank.py b/torchmetrics/functional/retrieval/reciprocal_rank.py index a2a6cedce2a..3daed08ac33 100644 --- a/torchmetrics/functional/retrieval/reciprocal_rank.py +++ b/torchmetrics/functional/retrieval/reciprocal_rank.py @@ -18,7 +18,7 @@ def retrieval_reciprocal_rank(preds: Tensor, target: Tensor) -> Tensor: - r""" + """ Computes reciprocal rank (for information retrieval), as explained `here `__. @@ -34,6 +34,7 @@ def retrieval_reciprocal_rank(preds: Tensor, target: Tensor) -> Tensor: a single-value tensor with the reciprocal rank (RR) of the predictions ``preds`` wrt the labels ``target``. Example: + >>> from torchmetrics.functional import retrieval_reciprocal_rank >>> preds = torch.tensor([0.2, 0.3, 0.5]) >>> target = torch.tensor([False, True, False]) >>> retrieval_reciprocal_rank(preds, target) diff --git a/torchmetrics/retrieval/__init__.py b/torchmetrics/retrieval/__init__.py index 3f1d96afe1c..5af6c2e2b70 100644 --- a/torchmetrics/retrieval/__init__.py +++ b/torchmetrics/retrieval/__init__.py @@ -14,3 +14,4 @@ from torchmetrics.retrieval.mean_average_precision import RetrievalMAP # noqa: F401 from torchmetrics.retrieval.mean_reciprocal_rank import RetrievalMRR # noqa: F401 from torchmetrics.retrieval.retrieval_metric import RetrievalMetric # noqa: F401 +from torchmetrics.retrieval.retrieval_precision import RetrievalPrecision # noqa: F401 diff --git a/torchmetrics/retrieval/mean_average_precision.py b/torchmetrics/retrieval/mean_average_precision.py index aef5feaa351..22e29a4eda5 100644 --- a/torchmetrics/retrieval/mean_average_precision.py +++ b/torchmetrics/retrieval/mean_average_precision.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch -from torch import Tensor +from torch import Tensor, tensor from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision from torchmetrics.retrieval.retrieval_metric import RetrievalMetric @@ -21,7 +20,7 @@ class RetrievalMAP(RetrievalMetric): """ Computes `Mean Average Precision - `_. + `__. Works with binary target data. Accepts float predictions from a model output. @@ -33,8 +32,8 @@ class RetrievalMAP(RetrievalMetric): ``indexes``, ``preds`` and ``target`` must have the same dimension. ``indexes`` indicate to which query a prediction belongs. - Predictions will be first grouped by ``indexes`` and then MAP will be computed as the mean - of the Average Precisions over each query. + Predictions will be first grouped by ``indexes`` and then `MAP` will be computed as the mean + of the `Average Precisions` over each query. 
Args: query_without_relevant_docs: @@ -60,9 +59,9 @@ class RetrievalMAP(RetrievalMetric): Example: >>> from torchmetrics import RetrievalMAP - >>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1]) - >>> preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) - >>> target = torch.tensor([False, False, True, False, True, False, True]) + >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) + >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) + >>> target = tensor([False, False, True, False, True, False, True]) >>> map = RetrievalMAP() >>> map(indexes, preds, target) tensor(0.7917) diff --git a/torchmetrics/retrieval/mean_reciprocal_rank.py b/torchmetrics/retrieval/mean_reciprocal_rank.py index 67923830aae..7d649700a49 100644 --- a/torchmetrics/retrieval/mean_reciprocal_rank.py +++ b/torchmetrics/retrieval/mean_reciprocal_rank.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch -from torch import Tensor +from torch import Tensor, tensor from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank from torchmetrics.retrieval.retrieval_metric import RetrievalMetric @@ -21,7 +20,7 @@ class RetrievalMRR(RetrievalMetric): """ Computes `Mean Reciprocal Rank - `_. + `__. Works with binary target data. Accepts float predictions from a model output. @@ -33,8 +32,8 @@ class RetrievalMRR(RetrievalMetric): ``indexes``, ``preds`` and ``target`` must have the same dimension. ``indexes`` indicate to which query a prediction belongs. - Predictions will be first grouped by ``indexes`` and then MRR will be computed as the mean - of the Reciprocal Rank over each query. + Predictions will be first grouped by ``indexes`` and then `MRR` will be computed as the mean + of the `Reciprocal Rank` over each query. Args: query_without_relevant_docs: @@ -61,9 +60,9 @@ class RetrievalMRR(RetrievalMetric): Example: >>> from torchmetrics import RetrievalMRR - >>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1]) - >>> preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) - >>> target = torch.tensor([False, False, True, False, True, False, True]) + >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) + >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) + >>> target = tensor([False, False, True, False, True, False, True]) >>> mrr = RetrievalMRR() >>> mrr(indexes, preds, target) tensor(0.7500) diff --git a/torchmetrics/retrieval/retrieval_precision.py b/torchmetrics/retrieval/retrieval_precision.py new file mode 100644 index 00000000000..c82cecc97d1 --- /dev/null +++ b/torchmetrics/retrieval/retrieval_precision.py @@ -0,0 +1,99 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any, Callable, Optional + +from torch import Tensor, tensor + +from torchmetrics.functional.retrieval.precision import retrieval_precision +from torchmetrics.retrieval.retrieval_metric import IGNORE_IDX, RetrievalMetric + + +class RetrievalPrecision(RetrievalMetric): + """ + Computes `Precision + `__. + + Works with binary target data. Accepts float predictions from a model output. + + Forward accepts: + + - ``indexes`` (long tensor): ``(N, ...)`` + - ``preds`` (float tensor): ``(N, ...)`` + - ``target`` (long or bool tensor): ``(N, ...)`` + + ``indexes``, ``preds`` and ``target`` must have the same dimension. + ``indexes`` indicate to which query a prediction belongs. + Predictions will be first grouped by ``indexes`` and then `Precision` will be computed as the mean + of the `Precision` over each query. + + Args: + query_without_relevant_docs: + Specify what to do with queries that do not have at least a positive ``target``. Choose from: + + - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned + - ``'error'``: raise a ``ValueError`` + - ``'pos'``: score on those queries is counted as ``1.0`` + - ``'neg'``: score on those queries is counted as ``0.0`` + + exclude: + Do not take into account predictions where the ``target`` is equal to this value. default `-100` + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects + the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. default: None + k: consider only the top k elements for each query. 
default: None + + Example: + >>> from torchmetrics import RetrievalPrecision + >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) + >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) + >>> target = tensor([False, False, True, False, True, False, True]) + >>> p2 = RetrievalPrecision(k=2) + >>> p2(indexes, preds, target) + tensor(0.5000) + """ + + def __init__( + self, + query_without_relevant_docs: str = 'skip', + exclude: int = IGNORE_IDX, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + k: int = None + ): + super().__init__( + query_without_relevant_docs=query_without_relevant_docs, + exclude=exclude, + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn + ) + + if (k is not None) and not (isinstance(k, int) and k > 0): + raise ValueError("`k` has to be a positive integer or None") + self.k = k + + def _metric(self, preds: Tensor, target: Tensor) -> Tensor: + valid_indexes = (target != self.exclude) + return retrieval_precision(preds[valid_indexes], target[valid_indexes], k=self.k) diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index d6434186db4..1ee0bc72367 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -242,7 +242,6 @@ def get_group_indexes(idx: Tensor) -> List[Tensor]: A list of integer `torch.Tensor`s Example: - >>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1]) >>> groups = get_group_indexes(indexes) >>> groups diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py index bae3e35d211..c4949a67d17 100644 --- a/torchmetrics/utilities/imports.py +++ b/torchmetrics/utilities/imports.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Import utilities""" +import operator from distutils.version import LooseVersion from importlib import import_module from importlib.util import find_spec -import torch from pkg_resources import DistributionNotFound @@ -60,7 +61,8 @@ def _compare_version(package: str, op, version) -> bool: return op(pkg_version, LooseVersion(version)) -_TORCH_LOWER_1_4 = LooseVersion(torch.__version__) < LooseVersion("1.4.0") -_TORCH_LOWER_1_5 = LooseVersion(torch.__version__) < LooseVersion("1.5.0") -_TORCH_LOWER_1_6 = LooseVersion(torch.__version__) < LooseVersion("1.6.0") -_TORCH_GREATER_EQUAL_1_6 = LooseVersion(torch.__version__) >= LooseVersion("1.6.0") +_TORCH_LOWER_1_4 = _compare_version("torch", operator.lt, "1.4.0") +_TORCH_LOWER_1_5 = _compare_version("torch", operator.lt, "1.5.0") +_TORCH_LOWER_1_6 = _compare_version("torch", operator.lt, "1.6.0") +_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0") +_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0") diff --git a/torchmetrics/wrappers/__init__.py b/torchmetrics/wrappers/__init__.py new file mode 100644 index 00000000000..4f506ea4da3 --- /dev/null +++ b/torchmetrics/wrappers/__init__.py @@ -0,0 +1,14 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from torchmetrics.wrappers.bootstrapping import BootStrapper  # noqa: F401
diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py
new file mode 100644
index 00000000000..1bd8be2040b
--- /dev/null
+++ b/torchmetrics/wrappers/bootstrapping.py
@@ -0,0 +1,174 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from copy import deepcopy
+from typing import Any, Callable, Dict, Optional, Union
+
+import torch
+from torch import Tensor, nn
+
+from torchmetrics.metric import Metric
+from torchmetrics.utilities import apply_to_collection
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_7
+
+
+def _bootstrap_sampler(
+    size: int,
+    sampling_strategy: str = 'poisson'
+) -> Tensor:
+    """ Sample indices (with replacement) that can be used to resample a tensor along its first dimension.
+
+    Args:
+        size: number of samples
+        sampling_strategy: the strategy to use for sampling, either ``'poisson'`` or ``'multinomial'``
+
+    Returns:
+        tensor with the sampled indices
+
+    """
+    if sampling_strategy == 'poisson':
+        p = torch.distributions.Poisson(1)
+        n = p.sample((size,))
+        return torch.arange(size).repeat_interleave(n.long(), dim=0)
+    elif sampling_strategy == 'multinomial':
+        idx = torch.multinomial(
+            torch.ones(size),
+            num_samples=size,
+            replacement=True
+        )
+        return idx
+    raise ValueError('Unknown sampling strategy')
+
+
+class BootStrapper(Metric):
+
+    def __init__(
+        self,
+        base_metric: Metric,
+        num_bootstraps: int = 10,
+        mean: bool = True,
+        std: bool = True,
+        quantile: Optional[Union[float, Tensor]] = None,
+        raw: bool = False,
+        sampling_strategy: str = 'poisson',
+        compute_on_step: bool = True,
+        dist_sync_on_step: bool = False,
+        process_group: Optional[Any] = None,
+        dist_sync_fn: Callable = None
+    ) -> None:
+        r"""
+        Used to turn a metric into a `bootstrapped `_
+        metric that can automate the process of getting confidence intervals for metric values. This wrapper
+        class basically keeps multiple copies of the same base metric in memory and whenever ``update`` or
+        ``forward`` is called, all input tensors are resampled (with replacement) along the first dimension.
+
+        Args:
+            base_metric:
+                base metric class to wrap
+            num_bootstraps:
+                number of copies to make of the base metric for bootstrapping
+            mean:
+                if ``True`` return the mean of the bootstraps
+            std:
+                if ``True`` return the standard deviation of the bootstraps
+            quantile:
+                if given, returns the quantile of the bootstraps. Can only be used with
+                pytorch version 1.7 or higher
+            raw:
+                if ``True``, return all bootstrapped values
+            sampling_strategy:
+                Determines how to produce bootstrapped samplings. Either ``'poisson'`` or ``'multinomial'``.
+                If ``'poisson'`` is chosen, the number of times each sample will be included in the bootstrap
+                will be given by :math:`n\sim Poisson(\lambda=1)`, which approximates the true bootstrap distribution
+                when the number of samples is large. If ``'multinomial'`` is chosen, we will apply true bootstrapping
+                at the batch level to approximate bootstrapping over the whole dataset.
+            compute_on_step:
+                Forward only calls ``update()`` and returns ``None`` if this is set to ``False``.
+            dist_sync_on_step:
+                Synchronize metric state across processes at each ``forward()``
+                before returning the value at the step
+            process_group:
+                Specify the process group on which synchronization is called.
+                default: ``None`` (which selects the entire world)
+            dist_sync_fn:
+                Callback that performs the allgather operation on the metric state. When ``None``, DDP
+                will be used to perform the allgather.
+
+        Example::
+            >>> from pprint import pprint
+            >>> from torchmetrics import Accuracy, BootStrapper
+            >>> _ = torch.manual_seed(123)
+            >>> base_metric = Accuracy()
+            >>> bootstrap = BootStrapper(base_metric, num_bootstraps=20)
+            >>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,)))
+            >>> output = bootstrap.compute()
+            >>> pprint(output)
+            {'mean': tensor(0.2205), 'std': tensor(0.0859)}
+
+        """
+        super().__init__(compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)
+        if not isinstance(base_metric, Metric):
+            raise ValueError(
+                "Expected base metric to be an instance of torchmetrics.Metric"
+                f" but received {base_metric}"
+            )
+
+        self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_bootstraps)])
+        self.num_bootstraps = num_bootstraps
+
+        self.mean = mean
+        self.std = std
+        if quantile is not None and not _TORCH_GREATER_EQUAL_1_7:
+            raise ValueError('quantile argument can only be used with pytorch v1.7 or higher')
+        self.quantile = quantile
+        self.raw = raw
+
+        allowed_sampling = ('poisson', 'multinomial')
+        if sampling_strategy not in allowed_sampling:
+            raise ValueError(
+                f"Expected argument ``sampling_strategy`` to be one of {allowed_sampling}"
+                f" but received {sampling_strategy}"
+            )
+        self.sampling_strategy = sampling_strategy
+
+    def update(self, *args: Any, **kwargs: Any) -> None:
+        """ Updates the state of the base metric. Any tensor passed in will be bootstrapped along dimension 0. """
+        for idx in range(self.num_bootstraps):
+            args_sizes = apply_to_collection(args, Tensor, len)
+            kwargs_sizes = list(apply_to_collection(kwargs, Tensor, len))
+            if len(args_sizes) > 0:
+                size = args_sizes[0]
+            elif len(kwargs_sizes) > 0:
+                size = kwargs_sizes[0]
+            else:
+                raise ValueError('None of the input contained tensors, so could not determine the sampling size')
+            sample_idx = _bootstrap_sampler(size, sampling_strategy=self.sampling_strategy)
+            new_args = apply_to_collection(args, Tensor, torch.index_select, dim=0, index=sample_idx)
+            new_kwargs = apply_to_collection(kwargs, Tensor, torch.index_select, dim=0, index=sample_idx)
+            self.metrics[idx].update(*new_args, **new_kwargs)
+
+    def compute(self) -> Dict[str, Tensor]:
+        """ Computes the bootstrapped metric values. Always returns a dict of tensors, which can contain the
+        following keys: ``mean``, ``std``, ``quantile`` and ``raw``, depending on how the class was initialized.
+        """
+        computed_vals = torch.stack([m.compute() for m in self.metrics], dim=0)
+        output_dict = {}
+        if self.mean:
+            output_dict['mean'] = computed_vals.mean(dim=0)
+        if self.std:
+            output_dict['std'] = computed_vals.std(dim=0)
+        if self.quantile is not None:
+            output_dict['quantile'] = torch.quantile(computed_vals, self.quantile)
+        if self.raw:
+            output_dict['raw'] = computed_vals
+        return output_dict
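
For context, here is a minimal usage sketch of the two features this patch introduces (`RetrievalPrecision` and the `BootStrapper` wrapper), assuming a torchmetrics build that already contains this diff. The `Recall` configuration and the random tensors are illustrative only, and the bootstrap output values will vary from run to run.

```python
import torch

from torchmetrics import BootStrapper, RetrievalPrecision
from torchmetrics.classification import Recall

# Precision@2 grouped by query: predictions are first grouped by `indexes`,
# Precision@2 is computed per query and the per-query values are averaged.
# (Same data as the docstring example in this patch, so the result is tensor(0.5000).)
indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
target = torch.tensor([False, False, True, False, True, False, True])
precision_at_2 = RetrievalPrecision(k=2)
print(precision_at_2(indexes, preds, target))  # tensor(0.5000)

# Bootstrapped confidence interval for a classification metric.
# The base metric and the random data here are only an example; `quantile`
# requires torch>=1.7 because it relies on torch.quantile.
bootstrap = BootStrapper(
    Recall(num_classes=5, average='micro'),
    num_bootstraps=50,
    quantile=torch.tensor([0.05, 0.95]),
    sampling_strategy='poisson',
)
cls_preds = torch.randint(5, (100,))
cls_target = torch.randint(5, (100,))
bootstrap.update(cls_preds, cls_target)
print(bootstrap.compute())  # dict with 'mean', 'std' and 'quantile' tensors
```

Since `BootStrapper` keeps `num_bootstraps` independent copies of the base metric and updates each of them on a resampled view of every batch, the cost of `update` and `compute` grows linearly with that argument.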