diff --git a/CHANGELOG.md b/CHANGELOG.md index 92aca1e0774..3f30ab1d83e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,6 +53,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Pearson metrics now only store 6 statistics instead of all predictions and targets ([#380](https://github.com/PyTorchLightning/metrics/pull/380)) +- Use `torch.argmax` instead of `torch.topk` when `k=1` for better performance ([#419](https://github.com/PyTorchLightning/metrics/pull/419)) + + ### Deprecated - Rename `r2score` >> `r2_score` and `kldivergence` >> `kl_divergence` in `functional` ([#371](https://github.com/PyTorchLightning/metrics/pull/371)) diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 40f81f77ccb..46648352e8f 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -94,7 +94,10 @@ def select_topk(prob_tensor: Tensor, topk: int = 1, dim: int = 1) -> Tensor: [1, 1, 0]], dtype=torch.int32) """ zeros = torch.zeros_like(prob_tensor) - topk_tensor = zeros.scatter(dim, prob_tensor.topk(k=topk, dim=dim).indices, 1.0) + if topk == 1: # argmax has better performance than topk + topk_tensor = zeros.scatter(dim, prob_tensor.argmax(dim=dim, keepdim=True), 1.0) + else: + topk_tensor = zeros.scatter(dim, prob_tensor.topk(k=topk, dim=dim).indices, 1.0) return topk_tensor.int()