diff --git a/CHANGELOG.md b/CHANGELOG.md index 769a6b7d4ad..75f0842b15a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed `top_k` for `multiclassf1score` with one-hot encoding ([#2839](https://github.com/Lightning-AI/torchmetrics/issues/2839)) +- Fixed slow calculations of classification metrics with MPS ([#2876](https://github.com/Lightning-AI/torchmetrics/issues/2876)) + --- ## [1.6.0] - 2024-11-12 diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index a4c80f6d4a1..dc67d5a4e34 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -199,7 +199,7 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor: if minlength is None: minlength = len(torch.unique(x)) - if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE and x.is_mps: + if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or x.is_mps: mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1) return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0)