Implemented R-Precision for IR (#577)
* implemented R-Precision for IR

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* changelog

* update changelog

* update changelog

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: SkafteNicki <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
5 people authored Oct 22, 2021
1 parent 40007e9 commit 5bbfac5
Showing 11 changed files with 279 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -47,6 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Short Term Objective Intelligibility (`STOI`) ([#353](https://github.com/PyTorchLightning/metrics/issues/353))


- Added `RetrievalRPrecision` metric to retrieval package ([#577](https://github.com/PyTorchLightning/metrics/pull/577/))


- Added `RetrievalHitRate` metric to retrieval package ([#576](https://github.com/PyTorchLightning/metrics/pull/576))


1 change: 1 addition & 0 deletions docs/source/links.rst
@@ -5,6 +5,7 @@
.. _Fall-out: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Fall-out
.. _Normalized Discounted Cumulative Gain: https://en.wikipedia.org/wiki/Discounted_cumulative_gain
.. _IR Precision: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Precision
.. _IR R-Precision: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#R-precision
.. _IR Recall: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Recall
.. _Accuracy: https://en.wikipedia.org/wiki/Accuracy_and_precision
.. _SMAPE: https://en.wikipedia.org/wiki/Symmetric_mean_absolute_percentage_error
7 changes: 7 additions & 0 deletions docs/source/references/functional.rst
@@ -385,6 +385,13 @@ retrieval_precision [func]
    :noindex:


retrieval_r_precision [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.retrieval_r_precision
    :noindex:


retrieval_recall [func]
~~~~~~~~~~~~~~~~~~~~~~~

7 changes: 7 additions & 0 deletions docs/source/references/modules.rst
@@ -554,6 +554,13 @@ RetrievalPrecision
    :noindex:


RetrievalRPrecision
~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.RetrievalRPrecision
    :noindex:


RetrievalRecall
~~~~~~~~~~~~~~~

136 changes: 136 additions & 0 deletions tests/retrieval/test_r_precision.py
@@ -0,0 +1,136 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
from torch import Tensor

from tests.helpers import seed_all
from tests.retrieval.helpers import (
    RetrievalMetricTester,
    _concat_tests,
    _default_metric_class_input_arguments,
    _default_metric_functional_input_arguments,
    _errors_test_class_metric_parameters_default,
    _errors_test_class_metric_parameters_no_pos_target,
    _errors_test_functional_metric_parameters_default,
)
from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
from torchmetrics.retrieval.retrieval_r_precision import RetrievalRPrecision

seed_all(42)


def _r_precision(target: np.ndarray, preds: np.ndarray):
    """Reference implementation of R-Precision for information retrieval.

    No reliable third-party implementation was found, so it is reimplemented here. A good explanation can be found
    `here <https://web.stanford.edu/class/cs276/handouts/EvaluationNew-handout-1-per.pdf>`_.
    """
    assert target.shape == preds.shape
    assert len(target.shape) == 1  # works only with single dimension inputs

    if target.sum() > 0:
        order_indexes = np.argsort(preds, axis=0)[::-1]
        relevant = np.sum(target[order_indexes][: target.sum()])
        return relevant * 1.0 / target.sum()
    return np.NaN


class TestRPrecision(RetrievalMetricTester):
    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
    def test_class_metric(
        self,
        ddp: bool,
        indexes: Tensor,
        preds: Tensor,
        target: Tensor,
        dist_sync_on_step: bool,
        empty_target_action: str,
    ):
        metric_args = {"empty_target_action": empty_target_action}

        self.run_class_metric_test(
            ddp=ddp,
            indexes=indexes,
            preds=preds,
            target=target,
            metric_class=RetrievalRPrecision,
            sk_metric=_r_precision,
            dist_sync_on_step=dist_sync_on_step,
            metric_args=metric_args,
        )

    @pytest.mark.parametrize(**_default_metric_functional_input_arguments)
    def test_functional_metric(self, preds: Tensor, target: Tensor):
        self.run_functional_metric_test(
            preds=preds,
            target=target,
            metric_functional=retrieval_r_precision,
            sk_metric=_r_precision,
            metric_args={},
        )

    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
    def test_precision_cpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
        self.run_precision_test_cpu(
            indexes=indexes,
            preds=preds,
            target=target,
            metric_module=RetrievalRPrecision,
            metric_functional=retrieval_r_precision,
        )

    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
    def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
        self.run_precision_test_gpu(
            indexes=indexes,
            preds=preds,
            target=target,
            metric_module=RetrievalRPrecision,
            metric_functional=retrieval_r_precision,
        )

    @pytest.mark.parametrize(
        **_concat_tests(
            _errors_test_class_metric_parameters_default,
            _errors_test_class_metric_parameters_no_pos_target,
        )
    )
    def test_arguments_class_metric(
        self, indexes: Tensor, preds: Tensor, target: Tensor, message: str, metric_args: dict
    ):
        self.run_metric_class_arguments_test(
            indexes=indexes,
            preds=preds,
            target=target,
            metric_class=RetrievalRPrecision,
            message=message,
            metric_args=metric_args,
            exception_type=ValueError,
            kwargs_update={},
        )

    @pytest.mark.parametrize(**_errors_test_functional_metric_parameters_default)
    def test_arguments_functional_metric(self, preds: Tensor, target: Tensor, message: str, metric_args: dict):
        self.run_functional_metric_arguments_test(
            preds=preds,
            target=target,
            metric_functional=retrieval_r_precision,
            message=message,
            exception_type=ValueError,
            kwargs_update=metric_args,
        )
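For orientation, the reference ``_r_precision`` defined in the test file above can be checked against the new functional metric directly, which is essentially the comparison the tester helpers drive across many parametrizations. The snippet below is a minimal illustrative sketch, not part of the commit; the input values are made up.

# Illustrative sketch: the comparison the parametrized tests perform, on one hand-made example.
# ``_r_precision`` is the NumPy reference defined in the test file above.
import numpy as np
import torch

from torchmetrics.functional import retrieval_r_precision

preds = np.array([0.1, 0.4, 0.35, 0.8], dtype=np.float32)
target = np.array([False, True, False, True])

reference = _r_precision(target, preds)
result = retrieval_r_precision(torch.from_numpy(preds), torch.from_numpy(target))

assert np.isclose(reference, result.item())  # both evaluate to 1.0 on this input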
2 changes: 2 additions & 0 deletions torchmetrics/__init__.py
@@ -63,6 +63,7 @@
    RetrievalNormalizedDCG,
    RetrievalPrecision,
    RetrievalRecall,
    RetrievalRPrecision,
)
from torchmetrics.text import WER, BERTScore, BLEUScore, ROUGEScore, SacreBLEUScore # noqa: E402
from torchmetrics.wrappers import BootStrapper, MetricTracker, MultioutputWrapper # noqa: E402
@@ -123,6 +124,7 @@
"RetrievalNormalizedDCG",
"RetrievalPrecision",
"RetrievalRecall",
"RetrievalRPrecision",
"ROC",
"ROUGEScore",
"SacreBLEUScore",
2 changes: 2 additions & 0 deletions torchmetrics/functional/__init__.py
@@ -61,6 +61,7 @@
from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
from torchmetrics.functional.retrieval.precision import retrieval_precision
from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
from torchmetrics.functional.retrieval.recall import retrieval_recall
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank
from torchmetrics.functional.self_supervised import embedding_similarity
@@ -118,6 +119,7 @@
"retrieval_hit_rate",
"retrieval_normalized_dcg",
"retrieval_precision",
"retrieval_r_precision",
"retrieval_recall",
"retrieval_reciprocal_rank",
"roc",
1 change: 1 addition & 0 deletions torchmetrics/functional/retrieval/__init__.py
@@ -17,5 +17,6 @@
from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate # noqa: F401
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg # noqa: F401
from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401
from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision # noqa: F401
from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401
49 changes: 49 additions & 0 deletions torchmetrics/functional/retrieval/r_precision.py
@@ -0,0 +1,49 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor, tensor

from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def retrieval_r_precision(preds: Tensor, target: Tensor) -> Tensor:
"""Computes the r-precision metric (for information retrieval). R-Precision is the fraction of relevant
documents among all the top ``k`` retrieved documents where ``k`` is equal to the total number of relevant
documents.
``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,
otherwise an error is raised. If you want to measure Precision@K, ``k`` must be a positive integer.
Args:
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document being relevant or not.
Returns:
a single-value tensor with the r-precision of the predictions ``preds`` w.r.t. the labels ``target``.
Example:
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> retrieval_r_precision(preds, target)
tensor(0.5000)
"""
preds, target = _check_retrieval_functional_inputs(preds, target)

relevant_number = target.sum()
if not relevant_number:
return tensor(0.0, device=preds.device)

relevant = target[torch.argsort(preds, dim=-1, descending=True)][:relevant_number].sum().float()
return relevant / relevant_number
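As a quick orientation, here is a minimal usage sketch of the new functional metric, cross-checked against a plain NumPy computation. It is illustrative only, not part of the commit, and the input values are invented for the example.

# Illustrative sketch: cross-check retrieval_r_precision against a direct NumPy computation.
import numpy as np
import torch

from torchmetrics.functional import retrieval_r_precision

preds = torch.tensor([0.2, 0.3, 0.5, 0.1])         # made-up relevance scores
target = torch.tensor([True, False, True, False])  # made-up relevance labels

# R-Precision is Precision@k with k equal to the number of relevant documents (here k=2).
print(retrieval_r_precision(preds, target))  # tensor(0.5000)

# The same computation in plain NumPy: take the top-k scored documents and count the relevant ones.
k = int(target.sum())
top_k = np.argsort(preds.numpy())[::-1][:k]
print(target.numpy()[top_k].sum() / k)  # 0.5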
1 change: 1 addition & 0 deletions torchmetrics/retrieval/__init__.py
@@ -18,4 +18,5 @@
from torchmetrics.retrieval.retrieval_metric import RetrievalMetric # noqa: F401
from torchmetrics.retrieval.retrieval_ndcg import RetrievalNormalizedDCG # noqa: F401
from torchmetrics.retrieval.retrieval_precision import RetrievalPrecision # noqa: F401
from torchmetrics.retrieval.retrieval_r_precision import RetrievalRPrecision # noqa: F401
from torchmetrics.retrieval.retrieval_recall import RetrievalRecall # noqa: F401
70 changes: 70 additions & 0 deletions torchmetrics/retrieval/retrieval_r_precision.py
@@ -0,0 +1,70 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch import Tensor, tensor

from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
from torchmetrics.retrieval.retrieval_metric import RetrievalMetric


class RetrievalRPrecision(RetrievalMetric):
"""Computes `IR R-Precision`_.
Works with binary target data. Accepts float predictions from a model output.
Forward accepts:
- ``preds`` (float tensor): ``(N, ...)``
- ``target`` (long or bool tensor): ``(N, ...)``
- ``indexes`` (long tensor): ``(N, ...)``
``indexes``, ``preds`` and ``target`` must have the same dimension.
``indexes`` indicate to which query a prediction belongs.
Predictions will be first grouped by ``indexes`` and then `R-Precision` will be computed as the mean
of the `R-Precision` over each query.
Args:
empty_target_action:
Specify what to do with queries that do not have at least a positive ``target``. Choose from:
- ``'neg'``: those queries count as ``0.0`` (default)
- ``'pos'``: those queries count as ``1.0``
- ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
- ``'error'``: raise a ``ValueError``
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects
the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather. default: None
Example:
>>> from torchmetrics import RetrievalRPrecision
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> p2 = RetrievalRPrecision()
>>> p2(preds, target, indexes=indexes)
tensor(0.7500)
"""

higher_is_better = True

def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
return retrieval_r_precision(preds, target)
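To make the per-query grouping described in the docstring concrete, here is a small illustrative sketch, not part of the commit, that reuses the docstring values and adds a third query with no positive target to show ``empty_target_action='skip'``.

# Illustrative sketch: predictions are grouped by ``indexes`` and per-query R-Precision is averaged.
import torch
from torchmetrics import RetrievalRPrecision

indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2])
preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2, 0.4, 0.6])
target = torch.tensor([False, False, True, False, True, False, True, False, False])

# Query 0 scores 1.0, query 1 scores 0.5, query 2 has no positive target.
# With 'skip', query 2 is left out of the average, so the result matches the
# docstring example: (1.0 + 0.5) / 2 = 0.75.
metric = RetrievalRPrecision(empty_target_action="skip")
print(metric(preds, target, indexes=indexes))  # tensor(0.7500)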
