From f4afaf510e71c1f4c6e7ac994949925aa8ce31b2 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Mon, 25 Apr 2022 22:51:26 +0900
Subject: [PATCH] using lru_cache (#980)

---
 torchmetrics/utilities/checks.py      | 12 ++++++------
 torchmetrics/utilities/data.py        |  4 ++--
 torchmetrics/utilities/distributed.py | 10 +++++-----
 torchmetrics/utilities/imports.py     |  4 ++++
 4 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py
index cad5d73b0f9..f688a873c80 100644
--- a/torchmetrics/utilities/checks.py
+++ b/torchmetrics/utilities/checks.py
@@ -243,7 +243,7 @@ def _check_classification_inputs(
             either from the shape of inputs, or the maximum label in the ``target`` and ``preds``
             tensor, where applicable.
         top_k:
-            Number of highest probability entries for each sample to convert to 1s - relevant
+            Number of the highest probability entries for each sample to convert to 1s - relevant
             only for inputs with probability predictions. The default value (``None``) will be
             interpreted as 1 for these inputs. If this parameter is set for multi-label inputs,
             it will take precedence over threshold.
@@ -342,7 +342,7 @@ def _input_format_classification(
 
     In binary case, targets are normally returned as ``(N,1)`` tensor, while preds are transformed
     into a binary tensor (elements become 1 if the probability is greater than or equal to
-    ``threshold`` or 0 otherwise). If ``multiclass=True``, then then both targets are preds
+    ``threshold`` or 0 otherwise). If ``multiclass=True``, then both targets and preds
     become ``(N, 2)`` tensors by a one-hot transformation; with the thresholding being applied
     to preds first.
 
@@ -461,7 +461,7 @@ def _input_format_classification_one_hot(
     Args:
         num_classes: number of classes
         preds: either tensor with labels, tensor with probabilities/logits or multilabel tensor
-        target: tensor with ground true labels
+        target: tensor with ground truth labels
         threshold: float used for thresholding multilabel input
         multilabel: boolean flag indicating if input is multilabel
 
@@ -503,7 +503,7 @@ def _check_retrieval_functional_inputs(
     target: Tensor,
     allow_non_binary_target: bool = False,
 ) -> Tuple[Tensor, Tensor]:
-    """Check ``preds`` and ``target`` tensors are of the same shape and of the correct dtype.
+    """Check ``preds`` and ``target`` tensors are of the same shape and of the correct data type.
 
     Args:
         preds: either tensor with scores/logits
@@ -535,7 +535,7 @@ def _check_retrieval_inputs(
     allow_non_binary_target: bool = False,
     ignore_index: Optional[int] = None,
 ) -> Tuple[Tensor, Tensor, Tensor]:
-    """Check ``indexes``, ``preds`` and ``target`` tensors are of the same shape and of the correct dtype.
+    """Check ``indexes``, ``preds`` and ``target`` tensors are of the same shape and of the correct data type.
 
     Args:
         indexes: tensor with queries indexes
@@ -580,7 +580,7 @@ def _check_retrieval_target_and_prediction_types(
     target: Tensor,
     allow_non_binary_target: bool = False,
 ) -> Tuple[Tensor, Tensor]:
-    """Check ``preds`` and ``target`` tensors are of the same shape and of the correct dtype.
+    """Check ``preds`` and ``target`` tensors are of the same shape and of the correct data type.
 
     Args:
         preds: either tensor with scores/logits
diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py
index 003f5a66d9c..7ed5242c2b1 100644
--- a/torchmetrics/utilities/data.py
+++ b/torchmetrics/utilities/data.py
@@ -157,10 +157,10 @@ def apply_to_collection(
         data: the collection to apply the function to
         dtype: the given function will be applied to all elements of this dtype
         function: the function to apply
-        *args: positional arguments (will be forwarded to calls of ``function``)
+        *args: positional arguments (will be forwarded to the call of ``function``)
         wrong_dtype: the given function won't be applied if this type is specified and the given collections
             is of the :attr:`wrong_type` even if it is of type :attr`dtype`
-        **kwargs: keyword arguments (will be forwarded to calls of ``function``)
+        **kwargs: keyword arguments (will be forwarded to the call of ``function``)
 
     Returns:
         the resulting collection
diff --git a/torchmetrics/utilities/distributed.py b/torchmetrics/utilities/distributed.py
index f33f864bb04..46517ac8ac3 100644
--- a/torchmetrics/utilities/distributed.py
+++ b/torchmetrics/utilities/distributed.py
@@ -19,11 +19,11 @@
 from typing_extensions import Literal
 
 
-def reduce(to_reduce: Tensor, reduction: Literal["elementwise_mean", "sum", "none", None]) -> Tensor:
+def reduce(x: Tensor, reduction: Literal["elementwise_mean", "sum", "none", None]) -> Tensor:
     """Reduces a given tensor by a given reduction method.
 
     Args:
-        to_reduce: the tensor, which shall be reduced
+        x: the tensor, which shall be reduced
         reduction: a string specifying the reduction method ('elementwise_mean', 'none', 'sum')
 
     Return:
@@ -33,11 +33,11 @@ def reduce(to_reduce: Tensor, reduction: Literal["elementwise_mean", "sum", "non
         ValueError if an invalid reduction parameter was given
     """
     if reduction == "elementwise_mean":
-        return torch.mean(to_reduce)
+        return torch.mean(x)
     if reduction == "none" or reduction is None:
-        return to_reduce
+        return x
     if reduction == "sum":
-        return torch.sum(to_reduce)
+        return torch.sum(x)
     raise ValueError("Reduction parameter unknown.")
 
 
diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py
index 29e596717d6..6598eb5ea86 100644
--- a/torchmetrics/utilities/imports.py
+++ b/torchmetrics/utilities/imports.py
@@ -14,6 +14,7 @@
 """Import utilities."""
 import operator
 from collections import OrderedDict  # noqa: F401
+from functools import lru_cache
 from importlib import import_module
 from importlib.util import find_spec
 from typing import Callable, Optional
@@ -22,6 +23,7 @@
 from pkg_resources import DistributionNotFound, get_distribution
 
 
+@lru_cache()
 def _package_available(package_name: str) -> bool:
     """Check if a package is available in your environment.
 
@@ -40,6 +42,7 @@ def _package_available(package_name: str) -> bool:
         return False
 
 
+@lru_cache()
 def _module_available(module_path: str) -> bool:
     """Check if a module path is available in your environment.
 
@@ -64,6 +67,7 @@ def _module_available(module_path: str) -> bool:
         return True
 
 
+@lru_cache()
 def _compare_version(package: str, op: Callable, version: str) -> Optional[bool]:
     """Compare package version with some requirements.
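
Note on the pattern: the ``@lru_cache()`` decorators added in ``imports.py`` memoize the availability and version checks, so repeated calls to ``_package_available``, ``_module_available``, and ``_compare_version`` with the same arguments return the cached result instead of re-running the ``importlib``/``pkg_resources`` lookups. Below is a minimal sketch of that caching technique; the helper name ``_backend_available`` is hypothetical and only illustrates the idea, it is not the torchmetrics implementation.

# Sketch of the lru_cache memoization pattern; `_backend_available` is a
# hypothetical helper used for illustration only.
from functools import lru_cache
from importlib.util import find_spec


@lru_cache()  # default maxsize=128 is plenty for distinct package names
def _backend_available(package_name: str) -> bool:
    """Return True if the package can be imported; the result is cached per name."""
    try:
        return find_spec(package_name) is not None
    except ModuleNotFoundError:
        # find_spec raises when a parent package in a dotted path is missing
        return False


if __name__ == "__main__":
    # The first call per argument does the real lookup; later calls hit the cache.
    print(_backend_available("torch"), _backend_available("torch"))
    print(_backend_available.cache_info())  # e.g. CacheInfo(hits=1, misses=1, maxsize=128, currsize=1)

A consequence of memoizing is that each result is fixed for the lifetime of the process: if the environment changes at runtime (for example, a package is installed after the first check), the stale value is returned until ``cache_clear()`` is called on the decorated function.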