diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2bb38b37cf4..169e0e4597d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -47,6 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added Short Term Objective Intelligibility (`STOI`) ([#353](https://github.com/PyTorchLightning/metrics/issues/353))
 
+- Added `RetrievalHitRate` metric to retrieval package ([#576](https://github.com/PyTorchLightning/metrics/pull/576))
+
+
 ### Changed
 
 - `AveragePrecision` will now as default output the `macro` average for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))
diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst
index 2b4708ecea0..fbb582e3fc0 100644
--- a/docs/source/references/functional.rst
+++ b/docs/source/references/functional.rst
@@ -405,6 +405,13 @@ retrieval_normalized_dcg [func]
 .. autofunction:: torchmetrics.functional.retrieval_normalized_dcg
     :noindex:
 
+
+retrieval_hit_rate [func]
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: torchmetrics.functional.retrieval_hit_rate
+    :noindex:
+
 ****
 Text
 ****
diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst
index a6566018384..68bb911bbb1 100644
--- a/docs/source/references/modules.rst
+++ b/docs/source/references/modules.rst
@@ -574,6 +574,13 @@ RetrievalNormalizedDCG
 .. autoclass:: torchmetrics.RetrievalNormalizedDCG
     :noindex:
 
+
+RetrievalHitRate
+~~~~~~~~~~~~~~~~
+
+.. autoclass:: torchmetrics.RetrievalHitRate
+    :noindex:
+
 ****
 Text
 ****
diff --git a/tests/retrieval/test_hit_rate.py b/tests/retrieval/test_hit_rate.py
new file mode 100644
index 00000000000..d68badb9af5
--- /dev/null
+++ b/tests/retrieval/test_hit_rate.py
@@ -0,0 +1,147 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import pytest
+from torch import Tensor
+
+from tests.helpers import seed_all
+from tests.retrieval.helpers import (
+    RetrievalMetricTester,
+    _concat_tests,
+    _default_metric_class_input_arguments,
+    _default_metric_functional_input_arguments,
+    _errors_test_class_metric_parameters_default,
+    _errors_test_class_metric_parameters_k,
+    _errors_test_class_metric_parameters_no_pos_target,
+    _errors_test_functional_metric_parameters_default,
+    _errors_test_functional_metric_parameters_k,
+)
+from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate
+from torchmetrics.retrieval.retrieval_hit_rate import RetrievalHitRate
+
+seed_all(42)
+
+
+def _hit_rate_at_k(target: np.ndarray, preds: np.ndarray, k: int = None):
+    """Didn't find a reliable implementation of Hit Rate in Information Retrieval, so, reimplementing here."""
+    assert target.shape == preds.shape
+    assert len(target.shape) == 1  # works only with single dimension inputs
+
+    if k is None:
+        k = len(preds)
+
+    if target.sum() > 0:
+        order_indexes = np.argsort(preds, axis=0)[::-1]
+        relevant = np.sum(target[order_indexes][:k])
+        return float(relevant > 0.0)
+    return np.NaN
+
+
+class TestHitRate(RetrievalMetricTester):
+    @pytest.mark.parametrize("ddp", [True, False])
+    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
+    @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
+    @pytest.mark.parametrize("k", [None, 1, 4, 10])
+    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
+    def test_class_metric(
+        self,
+        ddp: bool,
+        indexes: Tensor,
+        preds: Tensor,
+        target: Tensor,
+        dist_sync_on_step: bool,
+        empty_target_action: str,
+        k: int,
+    ):
+        metric_args = {"empty_target_action": empty_target_action, "k": k}
+
+        self.run_class_metric_test(
+            ddp=ddp,
+            indexes=indexes,
+            preds=preds,
+            target=target,
+            metric_class=RetrievalHitRate,
+            sk_metric=_hit_rate_at_k,
+            dist_sync_on_step=dist_sync_on_step,
+            metric_args=metric_args,
+        )
+
+    @pytest.mark.parametrize(**_default_metric_functional_input_arguments)
+    @pytest.mark.parametrize("k", [None, 1, 4, 10])
+    def test_functional_metric(self, preds: Tensor, target: Tensor, k: int):
+        self.run_functional_metric_test(
+            preds=preds,
+            target=target,
+            metric_functional=retrieval_hit_rate,
+            sk_metric=_hit_rate_at_k,
+            metric_args={},
+            k=k,
+        )
+
+    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
+    def test_precision_cpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
+        self.run_precision_test_cpu(
+            indexes=indexes,
+            preds=preds,
+            target=target,
+            metric_module=RetrievalHitRate,
+            metric_functional=retrieval_hit_rate,
+        )
+
+    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
+    def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
+        self.run_precision_test_gpu(
+            indexes=indexes,
+            preds=preds,
+            target=target,
+            metric_module=RetrievalHitRate,
+            metric_functional=retrieval_hit_rate,
+        )
+
+    @pytest.mark.parametrize(
+        **_concat_tests(
+            _errors_test_class_metric_parameters_default,
+            _errors_test_class_metric_parameters_no_pos_target,
+            _errors_test_class_metric_parameters_k,
+        )
+    )
+    def test_arguments_class_metric(
+        self, indexes: Tensor, preds: Tensor, target: Tensor, message: str, metric_args: dict
+    ):
+        self.run_metric_class_arguments_test(
+            indexes=indexes,
+            preds=preds,
+            target=target,
+            metric_class=RetrievalHitRate,
+            message=message,
+            metric_args=metric_args,
+            exception_type=ValueError,
+            kwargs_update={},
+        )
+
+    @pytest.mark.parametrize(
+        **_concat_tests(
+            _errors_test_functional_metric_parameters_default,
+            _errors_test_functional_metric_parameters_k,
+        )
+    )
+    def test_arguments_functional_metric(self, preds: Tensor, target: Tensor, message: str, metric_args: dict):
+        self.run_functional_metric_arguments_test(
+            preds=preds,
+            target=target,
+            metric_functional=retrieval_hit_rate,
+            message=message,
+            exception_type=ValueError,
+            kwargs_update=metric_args,
+        )
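For reviewers who want to sanity-check the NumPy reference helper above against the new torch implementation outside the test harness, a minimal single-query comparison might look like the following (illustrative only, assuming this branch is installed; not part of the patch):

```python
import numpy as np
import torch

from torchmetrics.functional import retrieval_hit_rate

preds = np.array([0.2, 0.3, 0.5, 0.1])
target = np.array([0, 0, 1, 1])
k = 2

# reference logic from `_hit_rate_at_k`: is any relevant document among the top-k scores?
order = np.argsort(preds)[::-1]
reference = float(target[order][:k].sum() > 0)  # 1.0, the 0.5-scored document is relevant

# torch implementation added by this PR
result = retrieval_hit_rate(torch.from_numpy(preds), torch.from_numpy(target), k=k)
assert result.item() == reference
```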
diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py
index 3d65a6eb8bc..8ab7fe64c58 100644
--- a/torchmetrics/__init__.py
+++ b/torchmetrics/__init__.py
@@ -57,6 +57,7 @@
 )
 from torchmetrics.retrieval import (  # noqa: E402
     RetrievalFallOut,
+    RetrievalHitRate,
     RetrievalMAP,
     RetrievalMRR,
     RetrievalNormalizedDCG,
@@ -116,6 +117,7 @@
     "R2Score",
     "Recall",
     "RetrievalFallOut",
+    "RetrievalHitRate",
     "RetrievalMAP",
     "RetrievalMRR",
     "RetrievalNormalizedDCG",
diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py
index a25500a5642..911857a7ced 100644
--- a/torchmetrics/functional/__init__.py
+++ b/torchmetrics/functional/__init__.py
@@ -58,6 +58,7 @@
 from torchmetrics.functional.regression.tweedie_deviance import tweedie_deviance_score
 from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision
 from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out
+from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate
 from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
 from torchmetrics.functional.retrieval.precision import retrieval_precision
 from torchmetrics.functional.retrieval.recall import retrieval_recall
@@ -114,6 +115,7 @@
     "recall",
     "retrieval_average_precision",
     "retrieval_fall_out",
+    "retrieval_hit_rate",
     "retrieval_normalized_dcg",
     "retrieval_precision",
     "retrieval_recall",
diff --git a/torchmetrics/functional/retrieval/__init__.py b/torchmetrics/functional/retrieval/__init__.py
index 3e2eb4cbb23..0b3d2caa420 100644
--- a/torchmetrics/functional/retrieval/__init__.py
+++ b/torchmetrics/functional/retrieval/__init__.py
@@ -14,6 +14,7 @@
 
 from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision  # noqa: F401
 from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out  # noqa: F401
+from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate  # noqa: F401
 from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg  # noqa: F401
 from torchmetrics.functional.retrieval.precision import retrieval_precision  # noqa: F401
 from torchmetrics.functional.retrieval.recall import retrieval_recall  # noqa: F401
diff --git a/torchmetrics/functional/retrieval/fall_out.py b/torchmetrics/functional/retrieval/fall_out.py
index 304a1158168..0d3c898fe77 100644
--- a/torchmetrics/functional/retrieval/fall_out.py
+++ b/torchmetrics/functional/retrieval/fall_out.py
@@ -11,13 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional
+
 import torch
 from torch import Tensor, tensor
 
 from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
 
 
-def retrieval_fall_out(preds: Tensor, target: Tensor, k: int = None) -> Tensor:
+def retrieval_fall_out(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
     """Computes the Fall-out (for information retrieval), as explained in `IR Fall-out`_
 
     Fall-out is the fraction of non-relevant documents retrieved among all the non-relevant documents.
@@ -28,11 +30,15 @@ def retrieval_fall_out(preds: Tensor, target: Tensor, k: int = None) -> Tensor:
     Args:
         preds: estimated probabilities of each document to be relevant.
         target: ground truth about each document being relevant or not.
-        k: consider only the top k elements (default: None)
+        k: consider only the top k elements (default: None, which considers them all)
 
     Returns:
         a single-value tensor with the fall-out (at ``k``) of the predictions ``preds`` w.r.t. the labels ``target``.
 
+    Raises:
+        ValueError:
+            If the ``k`` parameter is neither `None` nor an integer larger than 0
+
     Example:
         >>> from torchmetrics.functional import retrieval_fall_out
         >>> preds = tensor([0.2, 0.3, 0.5])
diff --git a/torchmetrics/functional/retrieval/hit_rate.py b/torchmetrics/functional/retrieval/hit_rate.py
new file mode 100644
index 00000000000..2abfce61446
--- /dev/null
+++ b/torchmetrics/functional/retrieval/hit_rate.py
@@ -0,0 +1,57 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import torch
+from torch import Tensor, tensor
+
+from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
+
+
+def retrieval_hit_rate(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
+    """Computes the hit rate (for information retrieval). The hit rate is 1.0 if there is at least one relevant
+    document among all the top `k` retrieved documents.
+
+    ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
+    ``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,
+    otherwise an error is raised. If you want to measure HitRate@K, ``k`` must be a positive integer.
+
+    Args:
+        preds: estimated probabilities of each document to be relevant.
+        target: ground truth about each document being relevant or not.
+        k: consider only the top k elements (default: None, which considers them all)
+
+    Returns:
+        a single-value tensor with the hit rate (at ``k``) of the predictions ``preds`` w.r.t. the labels ``target``.
+
+    Raises:
+        ValueError:
+            If the ``k`` parameter is neither `None` nor an integer larger than 0
+
+    Example:
+        >>> preds = tensor([0.2, 0.3, 0.5])
+        >>> target = tensor([True, False, True])
+        >>> retrieval_hit_rate(preds, target, k=2)
+        tensor(1.)
+    """
+    preds, target = _check_retrieval_functional_inputs(preds, target)
+
+    if k is None:
+        k = preds.shape[-1]
+
+    if not (isinstance(k, int) and k > 0):
+        raise ValueError("`k` has to be a positive integer or None")
+
+    relevant = target[torch.argsort(preds, dim=-1, descending=True)][:k].sum()
+    return (relevant > 0).float()
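To make the new functional metric's docstring concrete, here is a small illustrative snippet (not part of the patch, assuming this branch is installed) covering the default ``k=None`` behaviour, an explicit ``k``, and the all-negative ``target`` case that returns ``0``:

```python
import torch

from torchmetrics.functional import retrieval_hit_rate

preds = torch.tensor([0.6, 0.1, 0.4, 0.3])
target = torch.tensor([False, True, False, False])

retrieval_hit_rate(preds, target)       # k=None: all 4 documents are considered -> tensor(1.)
retrieval_hit_rate(preds, target, k=1)  # only the 0.6-scored document, which is not relevant -> tensor(0.)

# no relevant documents at all -> tensor(0.)
retrieval_hit_rate(preds, torch.zeros(4, dtype=torch.long))
```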
+ """ + preds, target = _check_retrieval_functional_inputs(preds, target) + + if k is None: + k = preds.shape[-1] + + if not (isinstance(k, int) and k > 0): + raise ValueError("`k` has to be a positive integer or None") + + relevant = target[torch.argsort(preds, dim=-1, descending=True)][:k].sum() + return (relevant > 0).float() diff --git a/torchmetrics/functional/retrieval/ndcg.py b/torchmetrics/functional/retrieval/ndcg.py index 63effa64d1c..1675f570b6d 100644 --- a/torchmetrics/functional/retrieval/ndcg.py +++ b/torchmetrics/functional/retrieval/ndcg.py @@ -35,11 +35,15 @@ def retrieval_normalized_dcg(preds: Tensor, target: Tensor, k: Optional[int] = N Args: preds: estimated probabilities of each document to be relevant. target: ground truth about each document relevance. - k: consider only the top k elements (default: None) + k: consider only the top k elements (default: None, which considers them all) Return: a single-value tensor with the nDCG of the predictions ``preds`` w.r.t. the labels ``target``. + Raises: + ValueError: + If ``k`` parameter is not `None` or an integer larger than 0 + Example: >>> from torchmetrics.functional import retrieval_normalized_dcg >>> preds = torch.tensor([.1, .2, .3, 4, 70]) diff --git a/torchmetrics/functional/retrieval/precision.py b/torchmetrics/functional/retrieval/precision.py index 4a9fb1cba8a..327f4daadab 100644 --- a/torchmetrics/functional/retrieval/precision.py +++ b/torchmetrics/functional/retrieval/precision.py @@ -30,11 +30,15 @@ def retrieval_precision(preds: Tensor, target: Tensor, k: Optional[int] = None) Args: preds: estimated probabilities of each document to be relevant. target: ground truth about each document being relevant or not. - k: consider only the top k elements (default: None) + k: consider only the top k elements (default: None, which considers them all) Returns: a single-value tensor with the precision (at ``k``) of the predictions ``preds`` w.r.t. the labels ``target``. + Raises: + ValueError: + If ``k`` parameter is not `None` or an integer larger than 0 + Example: >>> preds = tensor([0.2, 0.3, 0.5]) >>> target = tensor([True, False, True]) diff --git a/torchmetrics/functional/retrieval/recall.py b/torchmetrics/functional/retrieval/recall.py index 8af52a706ab..bebfbcf37e4 100644 --- a/torchmetrics/functional/retrieval/recall.py +++ b/torchmetrics/functional/retrieval/recall.py @@ -30,11 +30,15 @@ def retrieval_recall(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Args: preds: estimated probabilities of each document to be relevant. target: ground truth about each document being relevant or not. - k: consider only the top k elements (default: None) + k: consider only the top k elements (default: None, which considers them all) Returns: a single-value tensor with the recall (at ``k``) of the predictions ``preds`` w.r.t. the labels ``target``. 
+    Raises:
+        ValueError:
+            If the ``k`` parameter is neither `None` nor an integer larger than 0
+
     Example:
         >>> from torchmetrics.functional import retrieval_recall
         >>> preds = tensor([0.2, 0.3, 0.5])
diff --git a/torchmetrics/retrieval/__init__.py b/torchmetrics/retrieval/__init__.py
index 64be34ae813..a245edf8163 100644
--- a/torchmetrics/retrieval/__init__.py
+++ b/torchmetrics/retrieval/__init__.py
@@ -14,6 +14,7 @@
 from torchmetrics.retrieval.mean_average_precision import RetrievalMAP  # noqa: F401
 from torchmetrics.retrieval.mean_reciprocal_rank import RetrievalMRR  # noqa: F401
 from torchmetrics.retrieval.retrieval_fallout import RetrievalFallOut  # noqa: F401
+from torchmetrics.retrieval.retrieval_hit_rate import RetrievalHitRate  # noqa: F401
 from torchmetrics.retrieval.retrieval_metric import RetrievalMetric  # noqa: F401
 from torchmetrics.retrieval.retrieval_ndcg import RetrievalNormalizedDCG  # noqa: F401
 from torchmetrics.retrieval.retrieval_precision import RetrievalPrecision  # noqa: F401
diff --git a/torchmetrics/retrieval/retrieval_fallout.py b/torchmetrics/retrieval/retrieval_fallout.py
index b99fde3bfa7..b7ea1c8cc5d 100644
--- a/torchmetrics/retrieval/retrieval_fallout.py
+++ b/torchmetrics/retrieval/retrieval_fallout.py
@@ -46,6 +46,7 @@ class RetrievalFallOut(RetrievalMetric):
             - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
             - ``'error'``: raise a ``ValueError``
 
+        k: consider only the top k elements for each query (default: None, which considers them all)
         compute_on_step:
             Forward only calls ``update()`` and return None if this is set to False. default: True
         dist_sync_on_step:
@@ -57,7 +58,10 @@
         dist_sync_fn:
             Callback that performs the allgather operation on the metric state. When `None`, DDP
             will be used to perform the allgather. default: None
-        k: consider only the top k elements for each query. default: None
+
+    Raises:
+        ValueError:
+            If the ``k`` parameter is neither `None` nor an integer larger than 0
 
     Example:
         >>> from torchmetrics import RetrievalFallOut
@@ -74,11 +78,11 @@
     def __init__(
         self,
         empty_target_action: str = "pos",
+        k: int = None,
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
         dist_sync_fn: Callable = None,
-        k: int = None,
     ) -> None:
         super().__init__(
             empty_target_action=empty_target_action,
diff --git a/torchmetrics/retrieval/retrieval_hit_rate.py b/torchmetrics/retrieval/retrieval_hit_rate.py
new file mode 100644
index 00000000000..5ec1249402a
--- /dev/null
+++ b/torchmetrics/retrieval/retrieval_hit_rate.py
@@ -0,0 +1,98 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, Optional
+
+from torch import Tensor, tensor
+
+from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate
+from torchmetrics.retrieval.retrieval_metric import RetrievalMetric
+
+
+class RetrievalHitRate(RetrievalMetric):
+    """Computes `IR HitRate`.
+
+    Works with binary target data. Accepts float predictions from a model output.
+
+    Forward accepts:
+
+    - ``preds`` (float tensor): ``(N, ...)``
+    - ``target`` (long or bool tensor): ``(N, ...)``
+    - ``indexes`` (long tensor): ``(N, ...)``
+
+    ``indexes``, ``preds`` and ``target`` must have the same dimension.
+    ``indexes`` indicate to which query a prediction belongs.
+    Predictions will be first grouped by ``indexes`` and then the `Hit Rate` will be computed as the mean
+    of the `Hit Rate` over each query.
+
+    Args:
+        empty_target_action:
+            Specify what to do with queries that do not have at least a positive ``target``. Choose from:
+
+            - ``'neg'``: those queries count as ``0.0`` (default)
+            - ``'pos'``: those queries count as ``1.0``
+            - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
+            - ``'error'``: raise a ``ValueError``
+
+        k: consider only the top k elements for each query (default: None, which considers them all)
+        compute_on_step:
+            Forward only calls ``update()`` and return None if this is set to False. default: True
+        dist_sync_on_step:
+            Synchronize metric state across processes at each ``forward()``
+            before returning the value at the step. default: False
+        process_group:
+            Specify the process group on which synchronization is called. default: None (which selects
+            the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When `None`, DDP
+            will be used to perform the allgather. default: None
+
+    Raises:
+        ValueError:
+            If the ``k`` parameter is neither `None` nor an integer larger than 0
+
+    Example:
+        >>> from torchmetrics import RetrievalHitRate
+        >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
+        >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
+        >>> target = tensor([True, False, False, False, True, False, True])
+        >>> hr2 = RetrievalHitRate(k=2)
+        >>> hr2(preds, target, indexes=indexes)
+        tensor(0.5000)
+    """
+
+    higher_is_better = True
+
+    def __init__(
+        self,
+        empty_target_action: str = "neg",
+        k: Optional[int] = None,
+        compute_on_step: bool = True,
+        dist_sync_on_step: bool = False,
+        process_group: Optional[Any] = None,
+        dist_sync_fn: Callable = None,
+    ) -> None:
+        super().__init__(
+            empty_target_action=empty_target_action,
+            compute_on_step=compute_on_step,
+            dist_sync_on_step=dist_sync_on_step,
+            process_group=process_group,
+            dist_sync_fn=dist_sync_fn,
+        )
+
+        if (k is not None) and not (isinstance(k, int) and k > 0):
+            raise ValueError("`k` has to be a positive integer or None")
+        self.k = k
+
+    def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
+        return retrieval_hit_rate(preds, target, k=self.k)
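Worked through by hand, the docstring example above returns ``tensor(0.5000)`` because predictions are grouped by ``indexes`` and the per-query hit rates are averaged; a sketch of that arithmetic (illustrative, not part of the patch, assuming this branch is installed):

```python
import torch

from torchmetrics import RetrievalHitRate

indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
target = torch.tensor([True, False, False, False, True, False, True])

# query 0: preds [0.2, 0.3, 0.5], target [True, False, False]
#   top-2 scores are 0.5 and 0.3, neither document is relevant -> hit rate 0.0
# query 1: preds [0.1, 0.3, 0.5, 0.2], target [False, True, False, True]
#   top-2 scores are 0.5 and 0.3, and the 0.3-scored document is relevant -> hit rate 1.0
# mean over the two queries -> 0.5
hr2 = RetrievalHitRate(k=2)
print(hr2(preds, target, indexes=indexes))  # tensor(0.5000)
```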
diff --git a/torchmetrics/retrieval/retrieval_ndcg.py b/torchmetrics/retrieval/retrieval_ndcg.py
index 9ddd87b1a9b..98e3b0866f4 100644
--- a/torchmetrics/retrieval/retrieval_ndcg.py
+++ b/torchmetrics/retrieval/retrieval_ndcg.py
@@ -44,6 +44,7 @@ class RetrievalNormalizedDCG(RetrievalMetric):
             - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
             - ``'error'``: raise a ``ValueError``
 
+        k: consider only the top k elements for each query (default: None, which considers them all)
         compute_on_step:
             Forward only calls ``update()`` and return None if this is set to False. default: True
         dist_sync_on_step:
@@ -55,7 +56,10 @@
         dist_sync_fn:
             Callback that performs the allgather operation on the metric state. When `None`, DDP
             will be used to perform the allgather. default: None
-        k: consider only the top k elements for each query. default: None
+
+    Raises:
+        ValueError:
+            If the ``k`` parameter is neither `None` nor an integer larger than 0
 
     Example:
         >>> from torchmetrics import RetrievalNormalizedDCG
@@ -72,11 +76,11 @@
     def __init__(
         self,
         empty_target_action: str = "neg",
+        k: int = None,
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
         dist_sync_fn: Callable = None,
-        k: int = None,
     ) -> None:
         super().__init__(
             empty_target_action=empty_target_action,
diff --git a/torchmetrics/retrieval/retrieval_precision.py b/torchmetrics/retrieval/retrieval_precision.py
index 48e2b58088d..4cf32625a5c 100644
--- a/torchmetrics/retrieval/retrieval_precision.py
+++ b/torchmetrics/retrieval/retrieval_precision.py
@@ -44,6 +44,7 @@ class RetrievalPrecision(RetrievalMetric):
             - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
             - ``'error'``: raise a ``ValueError``
 
+        k: consider only the top k elements for each query (default: None, which considers them all)
         compute_on_step:
             Forward only calls ``update()`` and return None if this is set to False. default: True
         dist_sync_on_step:
@@ -55,7 +56,10 @@
         dist_sync_fn:
             Callback that performs the allgather operation on the metric state. When `None`, DDP
             will be used to perform the allgather. default: None
-        k: consider only the top k elements for each query. default: None
+
+    Raises:
+        ValueError:
+            If the ``k`` parameter is neither `None` nor an integer larger than 0
 
     Example:
         >>> from torchmetrics import RetrievalPrecision
@@ -72,11 +76,11 @@
     def __init__(
         self,
         empty_target_action: str = "neg",
+        k: int = None,
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
         dist_sync_fn: Callable = None,
-        k: int = None,
     ) -> None:
         super().__init__(
             empty_target_action=empty_target_action,
diff --git a/torchmetrics/retrieval/retrieval_recall.py b/torchmetrics/retrieval/retrieval_recall.py
index 76836e9c4bd..e1304ab4841 100644
--- a/torchmetrics/retrieval/retrieval_recall.py
+++ b/torchmetrics/retrieval/retrieval_recall.py
@@ -44,6 +44,7 @@ class RetrievalRecall(RetrievalMetric):
             - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
             - ``'error'``: raise a ``ValueError``
 
+        k: consider only the top k elements for each query (default: None, which considers them all)
         compute_on_step:
             Forward only calls ``update()`` and return None if this is set to False. default: True
         dist_sync_on_step:
@@ -55,7 +56,10 @@
         dist_sync_fn:
             Callback that performs the allgather operation on the metric state. When `None`, DDP
             will be used to perform the allgather. default: None
-        k: consider only the top k elements for each query. default: None
+
+    Raises:
+        ValueError:
+            If the ``k`` parameter is neither `None` nor an integer larger than 0
 
     Example:
         >>> from torchmetrics import RetrievalRecall
@@ -72,11 +76,11 @@
     def __init__(
         self,
         empty_target_action: str = "neg",
+        k: int = None,
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
         dist_sync_fn: Callable = None,
-        k: int = None,
     ) -> None:
         super().__init__(
             empty_target_action=empty_target_action,
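Finally, the ``Raises`` sections added throughout this diff all describe the same ``k`` validation shared by the retrieval metrics; an illustrative check of that behaviour (not part of the patch, assuming this branch is installed):

```python
import torch

from torchmetrics import RetrievalHitRate
from torchmetrics.functional import retrieval_hit_rate

try:
    RetrievalHitRate(k=0)  # k must be a positive integer or None
except ValueError as err:
    print(err)  # `k` has to be a positive integer or None

try:
    retrieval_hit_rate(torch.tensor([0.1, 0.9]), torch.tensor([0, 1]), k=-1)
except ValueError as err:
    print(err)  # same message, raised by the functional form
```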