From 322a682c920e560a13f2e2ef724acbd97f56f561 Mon Sep 17 00:00:00 2001 From: Paul Grundmann Date: Fri, 16 Jul 2021 12:09:48 +0200 Subject: [PATCH 1/9] Should fix issue #377 - --- torchmetrics/utilities/checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index 405bc980859..f9648362940 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -531,7 +531,7 @@ def _check_retrieval_functional_inputs( if not preds.is_floating_point(): raise ValueError("`preds` must be a tensor of floats") - if not allow_non_binary_target and target.max() > 1 or target.min() < 0: + if not allow_non_binary_target and (target.max() > 1 or target.min() < 0): raise ValueError("`target` must contain `binary` values") return preds.float().flatten(), target.long().flatten() From 5b49af5370d38c497adbead19408c2eb780d9a12 Mon Sep 17 00:00:00 2001 From: Paul Grundmann Date: Mon, 19 Jul 2021 09:55:28 +0200 Subject: [PATCH 2/9] Add: - Test nDCG with negative relevance targets --- tests/retrieval/inputs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/retrieval/inputs.py b/tests/retrieval/inputs.py index f790e8cfa50..b929040d341 100644 --- a/tests/retrieval/inputs.py +++ b/tests/retrieval/inputs.py @@ -35,7 +35,7 @@ _input_retrieval_scores_non_binary_target = Input( indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)), preds=torch.rand(NUM_BATCHES, BATCH_SIZE), - target=torch.randint(high=4, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(low=-2, high=4, size=(NUM_BATCHES, BATCH_SIZE)), ) # with errors From d1db30c72dfe7f4d4cc21e191281924e5e3baacb Mon Sep 17 00:00:00 2001 From: Paul Grundmann Date: Tue, 20 Jul 2021 14:23:19 +0200 Subject: [PATCH 3/9] Fix: Check for non binary values for retrieval targets --- torchmetrics/utilities/checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index f9648362940..82c05870a29 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -575,7 +575,7 @@ def _check_retrieval_inputs( if target.dtype not in (torch.bool, torch.long, torch.int): raise ValueError("`target` must be a tensor of booleans or integers") - if not allow_non_binary_target and target.max() > 1 or target.min() < 0: + if not allow_non_binary_target and (target.max() > 1 or target.min() < 0): raise ValueError("`target` must contain `binary` values") return indexes.long().flatten(), preds.float().flatten(), target.long().flatten() From 3eac51ab691bb05806ffa405089973495c1aae9a Mon Sep 17 00:00:00 2001 From: Paul Grundmann Date: Mon, 26 Jul 2021 19:26:53 +0200 Subject: [PATCH 4/9] Fix: - Use the scikit-learn implementation of nDCG - Removed the test for non binary targets in test_ndcg.py and replaced the default parameters in the error test with a custom one that does not check for binary targets - set the _input_retrieval_scores_non_binary_target low to -1 to reduce the test failure rate --- tests/retrieval/helpers.py | 49 +++++++++++++++++++++++ tests/retrieval/inputs.py | 2 +- tests/retrieval/test_ndcg.py | 11 ++--- torchmetrics/functional/retrieval/ndcg.py | 14 ++++--- 4 files changed, 63 insertions(+), 13 deletions(-) diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index 43ed478058f..52aba60b73c 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -137,6 +137,19 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict: ] ) +_errors_test_functional_metric_parameters_with_nonbinary = dict( + argnames="preds,target,message,metric_args", + argvalues=[ + # check input shapes are consistent (func) + (_irs_mis_sz_fn.preds, _irs_mis_sz_fn.target, "`preds` and `target` must be of the same shape", {}), + # check input tensors are not empty + (_irs_empty.preds, _irs_empty.target, "`preds` and `target` must be non-empty and non-scalar tensors", {}), + # check on input dtypes + (_irs.preds.bool(), _irs.target, "`preds` must be a tensor of floats", {}), + (_irs.preds, _irs.target.float(), "`target` must be a tensor of booleans or integers", {}), + ] +) + _errors_test_functional_metric_parameters_k = dict( argnames="preds,target,message,metric_args", argvalues=[ @@ -167,6 +180,42 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict: ] ) +_errors_test_class_metric_parameters_with_nonbinary = dict( + argnames="indexes,preds,target,message,metric_args", + argvalues=[ + (None, _irs.preds, _irs.target, "`indexes` cannot be None", dict(empty_target_action="error")), + # check when input arguments are invalid + ( + _irs.indexes, _irs.preds, _irs.target, "`empty_target_action` received a wrong value `casual_argument`.", + dict(empty_target_action="casual_argument") + ), + # check input shapes are consistent + ( + _irs_mis_sz.indexes, _irs_mis_sz.preds, _irs_mis_sz.target, + "`indexes`, `preds` and `target` must be of the same shape", dict(empty_target_action="skip") + ), + # check input tensors are not empty + ( + _irs_empty.indexes, _irs_empty.preds, + _irs_empty.target, "`indexes`, `preds` and `target` must be non-empty and non-scalar tensors", + dict(empty_target_action="skip") + ), + # check on input dtypes + ( + _irs.indexes.bool(), _irs.preds, _irs.target, "`indexes` must be a tensor of long integers", + dict(empty_target_action="skip") + ), + ( + _irs.indexes, _irs.preds.bool(), _irs.target, "`preds` must be a tensor of floats", + dict(empty_target_action="skip") + ), + ( + _irs.indexes, _irs.preds, _irs.target.float(), "`target` must be a tensor of booleans or integers", + dict(empty_target_action="skip") + ) + ] +) + _errors_test_class_metric_parameters_default = dict( argnames="indexes,preds,target,message,metric_args", argvalues=[ diff --git a/tests/retrieval/inputs.py b/tests/retrieval/inputs.py index b929040d341..7ee1f58a91c 100644 --- a/tests/retrieval/inputs.py +++ b/tests/retrieval/inputs.py @@ -35,7 +35,7 @@ _input_retrieval_scores_non_binary_target = Input( indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)), preds=torch.rand(NUM_BATCHES, BATCH_SIZE), - target=torch.randint(low=-2, high=4, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(low=-1, high=4, size=(NUM_BATCHES, BATCH_SIZE)), ) # with errors diff --git a/tests/retrieval/test_ndcg.py b/tests/retrieval/test_ndcg.py index 93a67510b04..68ebddd76e4 100644 --- a/tests/retrieval/test_ndcg.py +++ b/tests/retrieval/test_ndcg.py @@ -22,11 +22,9 @@ _concat_tests, _default_metric_class_input_arguments_with_non_binary_target, _default_metric_functional_input_arguments_with_non_binary_target, - _errors_test_class_metric_parameters_default, _errors_test_class_metric_parameters_k, - _errors_test_class_metric_parameters_no_pos_target, - _errors_test_functional_metric_parameters_default, - _errors_test_functional_metric_parameters_k, + _errors_test_functional_metric_parameters_k, _errors_test_class_metric_parameters_with_nonbinary, + _errors_test_functional_metric_parameters_with_nonbinary, ) from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg from torchmetrics.retrieval.retrieval_ndcg import RetrievalNormalizedDCG @@ -114,8 +112,7 @@ def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor): @pytest.mark.parametrize( **_concat_tests( - _errors_test_class_metric_parameters_default, - _errors_test_class_metric_parameters_no_pos_target, + _errors_test_class_metric_parameters_with_nonbinary, _errors_test_class_metric_parameters_k, ) ) @@ -135,7 +132,7 @@ def test_arguments_class_metric( @pytest.mark.parametrize( **_concat_tests( - _errors_test_functional_metric_parameters_default, + _errors_test_functional_metric_parameters_with_nonbinary, _errors_test_functional_metric_parameters_k, ) ) diff --git a/torchmetrics/functional/retrieval/ndcg.py b/torchmetrics/functional/retrieval/ndcg.py index 211654efd41..9dc3f0d1d52 100644 --- a/torchmetrics/functional/retrieval/ndcg.py +++ b/torchmetrics/functional/retrieval/ndcg.py @@ -21,7 +21,7 @@ def _dcg(target: Tensor) -> Tensor: denom = torch.log2(torch.arange(target.shape[-1], device=target.device) + 2.0) - return (target / denom).sum() + return (target / denom).sum(dim=-1) def retrieval_normalized_dcg(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor: @@ -55,10 +55,14 @@ def retrieval_normalized_dcg(preds: Tensor, target: Tensor, k: Optional[int] = N if not (isinstance(k, int) and k > 0): raise ValueError("`k` has to be a positive integer or None") - if not target.sum(): - return tensor(0.0, device=preds.device) - sorted_target = target[torch.argsort(preds, dim=-1, descending=True)][:k] ideal_target = torch.sort(target, descending=True)[0][:k] - return _dcg(sorted_target) / _dcg(ideal_target) + ideal_dcg = _dcg(ideal_target) + target_dcg = _dcg(sorted_target) + + all_irrelevant = ideal_dcg == 0 + target_dcg[all_irrelevant] = 0 + target_dcg[~all_irrelevant] /= ideal_dcg[~all_irrelevant] + + return target_dcg.mean() From 4fe79b7fe365f6cee739ac789208c16a227f8b9e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Jul 2021 17:28:07 +0000 Subject: [PATCH 5/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/retrieval/test_ndcg.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/retrieval/test_ndcg.py b/tests/retrieval/test_ndcg.py index 68ebddd76e4..2543a31d9f0 100644 --- a/tests/retrieval/test_ndcg.py +++ b/tests/retrieval/test_ndcg.py @@ -23,7 +23,8 @@ _default_metric_class_input_arguments_with_non_binary_target, _default_metric_functional_input_arguments_with_non_binary_target, _errors_test_class_metric_parameters_k, - _errors_test_functional_metric_parameters_k, _errors_test_class_metric_parameters_with_nonbinary, + _errors_test_class_metric_parameters_with_nonbinary, + _errors_test_functional_metric_parameters_k, _errors_test_functional_metric_parameters_with_nonbinary, ) from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg From ce34b68c149032928474e1b1b2fa4a7d7f23fb4b Mon Sep 17 00:00:00 2001 From: Paul Grundmann Date: Tue, 27 Jul 2021 11:06:03 +0200 Subject: [PATCH 6/9] Fix: - removed unused imports in ndcg.py --- torchmetrics/functional/retrieval/ndcg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/retrieval/ndcg.py b/torchmetrics/functional/retrieval/ndcg.py index 9dc3f0d1d52..813b33b746b 100644 --- a/torchmetrics/functional/retrieval/ndcg.py +++ b/torchmetrics/functional/retrieval/ndcg.py @@ -14,7 +14,7 @@ from typing import Optional import torch -from torch import Tensor, tensor +from torch import Tensor from torchmetrics.utilities.checks import _check_retrieval_functional_inputs From 92de263bb42348c06038f0a6d378dca5b4457387 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 28 Jul 2021 14:56:30 +0200 Subject: [PATCH 7/9] changelog --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c022003bda7..0b8bff9be54 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,14 +13,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support in `nDCG` metric for target with values larger than 1 ([#343](https://github.com/PyTorchLightning/metrics/issues/343)) + - Added Word error rate (WER) ([#52](https://github.com/PyTorchLightning/metrics/issues/52)) + - Added Symmetric Mean Absolute Percentage error (SMAPE) ([#375](https://github.com/PyTorchLightning/metrics/issues/375)) - Allowed passing labels in (n_samples, n_classes) to `AveragePrecision` ([#386](https://github.com/PyTorchLightning/metrics/issues/386)) +- Added support for negative targets in `nDCG` metric ([#378](https://github.com/PyTorchLightning/metrics/pull/378)) + + ### Changed - Moved `psnr` and `ssim` from `functional.regression.*` to `functional.image.*` ([#382](https://github.com/PyTorchLightning/metrics/pull/382)) From 29411f20f35e9c86e3393cccb0834ac268dd820d Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 28 Jul 2021 14:58:33 +0200 Subject: [PATCH 8/9] more stable tests --- tests/retrieval/inputs.py | 6 +++--- torchmetrics/functional/retrieval/ndcg.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/retrieval/inputs.py b/tests/retrieval/inputs.py index 7ee1f58a91c..4a45f037d1d 100644 --- a/tests/retrieval/inputs.py +++ b/tests/retrieval/inputs.py @@ -33,9 +33,9 @@ ) _input_retrieval_scores_non_binary_target = Input( - indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)), - preds=torch.rand(NUM_BATCHES, BATCH_SIZE), - target=torch.randint(low=-1, high=4, size=(NUM_BATCHES, BATCH_SIZE)), + indexes=torch.randint(high=10, size=(NUM_BATCHES, 2*BATCH_SIZE)), + preds=torch.rand(NUM_BATCHES, 2*BATCH_SIZE), + target=torch.randint(low=-1, high=4, size=(NUM_BATCHES, 2*BATCH_SIZE)), ) # with errors diff --git a/torchmetrics/functional/retrieval/ndcg.py b/torchmetrics/functional/retrieval/ndcg.py index 813b33b746b..65ac97247ef 100644 --- a/torchmetrics/functional/retrieval/ndcg.py +++ b/torchmetrics/functional/retrieval/ndcg.py @@ -61,6 +61,7 @@ def retrieval_normalized_dcg(preds: Tensor, target: Tensor, k: Optional[int] = N ideal_dcg = _dcg(ideal_target) target_dcg = _dcg(sorted_target) + # filter undefined scores all_irrelevant = ideal_dcg == 0 target_dcg[all_irrelevant] = 0 target_dcg[~all_irrelevant] /= ideal_dcg[~all_irrelevant] From c29d8ea4e888a3a00595c62463f3ea8a473210f9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Jul 2021 12:59:12 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/retrieval/inputs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/retrieval/inputs.py b/tests/retrieval/inputs.py index 4a45f037d1d..4a2f93cd7fc 100644 --- a/tests/retrieval/inputs.py +++ b/tests/retrieval/inputs.py @@ -33,9 +33,9 @@ ) _input_retrieval_scores_non_binary_target = Input( - indexes=torch.randint(high=10, size=(NUM_BATCHES, 2*BATCH_SIZE)), - preds=torch.rand(NUM_BATCHES, 2*BATCH_SIZE), - target=torch.randint(low=-1, high=4, size=(NUM_BATCHES, 2*BATCH_SIZE)), + indexes=torch.randint(high=10, size=(NUM_BATCHES, 2 * BATCH_SIZE)), + preds=torch.rand(NUM_BATCHES, 2 * BATCH_SIZE), + target=torch.randint(low=-1, high=4, size=(NUM_BATCHES, 2 * BATCH_SIZE)), ) # with errors