Implemented R-Precision for IR (#577)
* implemented R-Precision for IR
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* changelog
* update changelog
* update changelog

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: SkafteNicki <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 40007e9 · commit 5bbfac5
Showing 11 changed files with 279 additions and 0 deletions.
@@ -0,0 +1,136 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
from torch import Tensor

from tests.helpers import seed_all
from tests.retrieval.helpers import (
    RetrievalMetricTester,
    _concat_tests,
    _default_metric_class_input_arguments,
    _default_metric_functional_input_arguments,
    _errors_test_class_metric_parameters_default,
    _errors_test_class_metric_parameters_no_pos_target,
    _errors_test_functional_metric_parameters_default,
)
from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
from torchmetrics.retrieval.retrieval_r_precision import RetrievalRPrecision

seed_all(42)


def _r_precision(target: np.ndarray, preds: np.ndarray):
    """Reference implementation of R-Precision for Information Retrieval.

    No reliable third-party implementation was found, so it is reimplemented here. A good explanation can be found
    `here <https://web.stanford.edu/class/cs276/handouts/EvaluationNew-handout-1-per.pdf>`_.
    """
    assert target.shape == preds.shape
    assert len(target.shape) == 1  # works only with single-dimension inputs

    if target.sum() > 0:
        order_indexes = np.argsort(preds, axis=0)[::-1]
        relevant = np.sum(target[order_indexes][: target.sum()])
        return relevant * 1.0 / target.sum()
    return np.NaN


class TestRPrecision(RetrievalMetricTester):
    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
    def test_class_metric(
        self,
        ddp: bool,
        indexes: Tensor,
        preds: Tensor,
        target: Tensor,
        dist_sync_on_step: bool,
        empty_target_action: str,
    ):
        metric_args = {"empty_target_action": empty_target_action}

        self.run_class_metric_test(
            ddp=ddp,
            indexes=indexes,
            preds=preds,
            target=target,
            metric_class=RetrievalRPrecision,
            sk_metric=_r_precision,
            dist_sync_on_step=dist_sync_on_step,
            metric_args=metric_args,
        )

    @pytest.mark.parametrize(**_default_metric_functional_input_arguments)
    def test_functional_metric(self, preds: Tensor, target: Tensor):
        self.run_functional_metric_test(
            preds=preds,
            target=target,
            metric_functional=retrieval_r_precision,
            sk_metric=_r_precision,
            metric_args={},
        )

    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
    def test_precision_cpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
        self.run_precision_test_cpu(
            indexes=indexes,
            preds=preds,
            target=target,
            metric_module=RetrievalRPrecision,
            metric_functional=retrieval_r_precision,
        )

    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
    def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
        self.run_precision_test_gpu(
            indexes=indexes,
            preds=preds,
            target=target,
            metric_module=RetrievalRPrecision,
            metric_functional=retrieval_r_precision,
        )

    @pytest.mark.parametrize(
        **_concat_tests(
            _errors_test_class_metric_parameters_default,
            _errors_test_class_metric_parameters_no_pos_target,
        )
    )
    def test_arguments_class_metric(
        self, indexes: Tensor, preds: Tensor, target: Tensor, message: str, metric_args: dict
    ):
        self.run_metric_class_arguments_test(
            indexes=indexes,
            preds=preds,
            target=target,
            metric_class=RetrievalRPrecision,
            message=message,
            metric_args=metric_args,
            exception_type=ValueError,
            kwargs_update={},
        )

    @pytest.mark.parametrize(**_errors_test_functional_metric_parameters_default)
    def test_arguments_functional_metric(self, preds: Tensor, target: Tensor, message: str, metric_args: dict):
        self.run_functional_metric_arguments_test(
            preds=preds,
            target=target,
            metric_functional=retrieval_r_precision,
            message=message,
            exception_type=ValueError,
            kwargs_update=metric_args,
        )
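For intuition, here is a small worked example of what the reference ``_r_precision`` helper above computes on a single query. It is only an illustrative sketch: the input values are made up, and it assumes the helper is in scope (e.g. imported from this test module).

import numpy as np

# Hypothetical single query with 4 documents, 2 of which are relevant.
preds = np.array([0.1, 0.4, 0.3, 0.8])
target = np.array([0, 1, 0, 1])

# R = target.sum() = 2, so only the 2 highest-scoring documents are inspected.
# Ranking by score gives documents [3, 1, 2, 0]; the top 2 (documents 3 and 1) are both relevant.
print(_r_precision(target, preds))  # 1.0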
@@ -0,0 +1,49 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor, tensor

from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def retrieval_r_precision(preds: Tensor, target: Tensor) -> Tensor:
    """Computes the R-Precision metric (for information retrieval). R-Precision is the fraction of relevant
    documents among the top ``k`` retrieved documents, where ``k`` is equal to the total number of relevant
    documents.

    ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
    ``0`` is returned. ``target`` must be either `bool` or `integer`, and ``preds`` must be `float`,
    otherwise an error is raised.

    Args:
        preds: estimated probabilities of each document to be relevant.
        target: ground truth about each document being relevant or not.

    Returns:
        a single-value tensor with the R-Precision of the predictions ``preds`` w.r.t. the labels ``target``.

    Example:
        >>> preds = tensor([0.2, 0.3, 0.5])
        >>> target = tensor([True, False, True])
        >>> retrieval_r_precision(preds, target)
        tensor(0.5000)
    """
    preds, target = _check_retrieval_functional_inputs(preds, target)

    relevant_number = target.sum()
    if not relevant_number:
        return tensor(0.0, device=preds.device)

    relevant = target[torch.argsort(preds, dim=-1, descending=True)][:relevant_number].sum().float()
    return relevant / relevant_number
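As a sanity check, the docstring example can be reproduced step by step with plain ``torch`` operations. This is just a sketch of the same ranking logic used by ``retrieval_r_precision`` above, not a separate API.

import torch

preds = torch.tensor([0.2, 0.3, 0.5])
target = torch.tensor([True, False, True])

r = int(target.sum())                          # number of relevant documents: 2
order = torch.argsort(preds, descending=True)  # documents ranked by score: [2, 1, 0]
hits = target[order][:r].sum().float()         # relevant documents among the top 2: 1
print(hits / r)                                # tensor(0.5000), same as retrieval_r_precision(preds, target)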
@@ -0,0 +1,70 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch import Tensor, tensor

from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
from torchmetrics.retrieval.retrieval_metric import RetrievalMetric


class RetrievalRPrecision(RetrievalMetric):
    """Computes `IR R-Precision`_.

    Works with binary target data. Accepts float predictions from a model output.

    Forward accepts:

    - ``preds`` (float tensor): ``(N, ...)``
    - ``target`` (long or bool tensor): ``(N, ...)``
    - ``indexes`` (long tensor): ``(N, ...)``

    ``indexes``, ``preds`` and ``target`` must have the same dimension.
    ``indexes`` indicate to which query a prediction belongs.
    Predictions will first be grouped by ``indexes`` and then `R-Precision` will be computed as the mean
    of `R-Precision` over each query.

    Args:
        empty_target_action:
            Specify what to do with queries that do not have at least a positive ``target``. Choose from:

            - ``'neg'``: those queries count as ``0.0`` (default)
            - ``'pos'``: those queries count as ``1.0``
            - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
            - ``'error'``: raise a ``ValueError``

        compute_on_step:
            Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: ``True``
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step. default: ``False``
        process_group:
            Specify the process group on which synchronization is called. default: ``None`` (which selects
            the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When ``None``, DDP
            will be used to perform the allgather. default: ``None``

    Example:
        >>> from torchmetrics import RetrievalRPrecision
        >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
        >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
        >>> target = tensor([False, False, True, False, True, False, True])
        >>> p2 = RetrievalRPrecision()
        >>> p2(preds, target, indexes=indexes)
        tensor(0.7500)
    """

    higher_is_better = True

    def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
        return retrieval_r_precision(preds, target)
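A minimal usage sketch of the module interface when predictions arrive in batches. The data is the docstring example split into two arbitrary updates, and the only assumption is that ``update`` accepts ``indexes`` as a keyword argument, as in the ``forward`` call shown above.

from torch import tensor
from torchmetrics import RetrievalRPrecision

metric = RetrievalRPrecision()

# Same data as the docstring example, fed in two batches; ``indexes`` ties every
# prediction to its query, so the result does not depend on how batches are split.
metric.update(tensor([0.2, 0.3, 0.5]), tensor([False, False, True]), indexes=tensor([0, 0, 0]))
metric.update(tensor([0.1, 0.3, 0.5, 0.2]), tensor([False, True, False, True]), indexes=tensor([1, 1, 1, 1]))

print(metric.compute())  # tensor(0.7500): the mean of 1.0 (query 0) and 0.5 (query 1)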