From 0fa7da76fe685888d64fe5eac6c4a36cfa21a9c7 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 2 Aug 2021 17:13:03 +0200 Subject: [PATCH 1/3] argmax for k=1 --- torchmetrics/utilities/data.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 9030af992c3..78287ebc9c1 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -97,7 +97,10 @@ def select_topk(prob_tensor: Tensor, topk: int = 1, dim: int = 1) -> Tensor: [1, 1, 0]], dtype=torch.int32) """ zeros = torch.zeros_like(prob_tensor) - topk_tensor = zeros.scatter(dim, prob_tensor.topk(k=topk, dim=dim).indices, 1.0) + if topk==1: # argmax has better performance than topk + topk_tensor = zeros.scatter(dim, prob_tensor.argmax(dim=dim, keepdim=True), 1.0) + else: + topk_tensor = zeros.scatter(dim, prob_tensor.topk(k=topk, dim=dim).indices, 1.0) return topk_tensor.int() From d17680fab0f6c4cdf839c819b5234e65c6f5361c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Aug 2021 07:53:33 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/utilities/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 33ea2765e2a..339b485276b 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -97,7 +97,7 @@ def select_topk(prob_tensor: Tensor, topk: int = 1, dim: int = 1) -> Tensor: [1, 1, 0]], dtype=torch.int32) """ zeros = torch.zeros_like(prob_tensor) - if topk==1: # argmax has better performance than topk + if topk == 1: # argmax has better performance than topk topk_tensor = zeros.scatter(dim, prob_tensor.argmax(dim=dim, keepdim=True), 1.0) else: topk_tensor = zeros.scatter(dim, prob_tensor.topk(k=topk, dim=dim).indices, 1.0) From c7cd7994d4ab5fd80f23b69e464d6ece2d810082 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 3 Aug 2021 09:54:45 +0200 Subject: [PATCH 3/3] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index cc59c72b828..939167ebbcc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,6 +49,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Pearson metrics now only store 6 statistics instead of all predictions and targets ([#380](https://github.com/PyTorchLightning/metrics/pull/380)) +- Use `torch.argmax` instead of `torch.topk` when `k=1` for better performance ([#419](https://github.com/PyTorchLightning/metrics/pull/419)) + + ### Deprecated - Rename `r2score` >> `r2_score` and `kldivergence` >> `kl_divergence` in `functional` ([#371](https://github.com/PyTorchLightning/metrics/pull/371))