Merge branch 'master' into refactoring/remove_deprecated_compute_on_step_image
SkafteNicki authored Apr 25, 2022
2 parents 0775e6c + f4afaf5 commit 3e3224a
Showing 4 changed files with 17 additions and 13 deletions.
12 changes: 6 additions & 6 deletions torchmetrics/utilities/checks.py
@@ -243,7 +243,7 @@ def _check_classification_inputs(
either from the shape of inputs, or the maximum label in the ``target`` and ``preds``
tensor, where applicable.
top_k:
- Number of highest probability entries for each sample to convert to 1s - relevant
+ Number of the highest probability entries for each sample to convert to 1s - relevant
only for inputs with probability predictions. The default value (``None``) will be
interpreted as 1 for these inputs. If this parameter is set for multi-label inputs,
it will take precedence over threshold.
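
To make the ``top_k`` behaviour concrete, here is a minimal standalone sketch (illustrative tensors only, not the library implementation):

    import torch

    preds = torch.tensor([[0.1, 0.6, 0.3]])  # probability predictions for one sample
    top_k = 2
    # the two highest-probability entries per sample are converted to 1s
    binary = torch.zeros_like(preds, dtype=torch.int)
    binary.scatter_(1, preds.topk(top_k, dim=1).indices, 1)
    # binary is now tensor([[0, 1, 1]])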
@@ -342,7 +342,7 @@ def _input_format_classification(
In binary case, targets are normally returned as ``(N,1)`` tensor, while preds are transformed
into a binary tensor (elements become 1 if the probability is greater than or equal to
- ``threshold`` or 0 otherwise). If ``multiclass=True``, then then both targets are preds
+ ``threshold`` or 0 otherwise). If ``multiclass=True``, then both targets are preds
become ``(N, 2)`` tensors by a one-hot transformation; with the thresholding being applied to
preds first.
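
For concreteness, a small sketch of the binary thresholding described above (assumed example values):

    import torch

    preds = torch.tensor([0.2, 0.7, 0.9])     # binary probability predictions
    threshold = 0.5
    # elements become 1 if the probability is >= threshold, 0 otherwise
    binary_preds = (preds >= threshold).int()  # tensor([0, 1, 1])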
@@ -461,7 +461,7 @@ def _input_format_classification_one_hot(
Args:
num_classes: number of classes
preds: either tensor with labels, tensor with probabilities/logits or multilabel tensor
- target: tensor with ground true labels
+ target: tensor with ground-true labels
threshold: float used for thresholding multilabel input
multilabel: boolean flag indicating if input is multilabel
@@ -503,7 +503,7 @@ def _check_retrieval_functional_inputs(
target: Tensor,
allow_non_binary_target: bool = False,
) -> Tuple[Tensor, Tensor]:
"""Check ``preds`` and ``target`` tensors are of the same shape and of the correct dtype.
"""Check ``preds`` and ``target`` tensors are of the same shape and of the correct data type.
Args:
preds: either tensor with scores/logits
@@ -535,7 +535,7 @@ def _check_retrieval_inputs(
allow_non_binary_target: bool = False,
ignore_index: Optional[int] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
"""Check ``indexes``, ``preds`` and ``target`` tensors are of the same shape and of the correct dtype.
"""Check ``indexes``, ``preds`` and ``target`` tensors are of the same shape and of the correct data type.
Args:
indexes: tensor with queries indexes
@@ -580,7 +580,7 @@ def _check_retrieval_target_and_prediction_types(
target: Tensor,
allow_non_binary_target: bool = False,
) -> Tuple[Tensor, Tensor]:
"""Check ``preds`` and ``target`` tensors are of the same shape and of the correct dtype.
"""Check ``preds`` and ``target`` tensors are of the same shape and of the correct data type.
Args:
preds: either tensor with scores/logits
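For context, a hedged sketch of inputs that satisfy the shape/data-type contract these checks enforce (calling a private helper; illustrative only):

    import torch
    from torchmetrics.utilities.checks import _check_retrieval_functional_inputs

    preds = torch.tensor([0.9, 0.3, 0.5])  # floating-point scores/logits
    target = torch.tensor([1, 0, 1])       # binary labels, same shape as preds
    preds, target = _check_retrieval_functional_inputs(preds, target)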
4 changes: 2 additions & 2 deletions torchmetrics/utilities/data.py
@@ -157,10 +157,10 @@ def apply_to_collection(
data: the collection to apply the function to
dtype: the given function will be applied to all elements of this dtype
function: the function to apply
- *args: positional arguments (will be forwarded to calls of ``function``)
+ *args: positional arguments (will be forwarded to call of ``function``)
wrong_dtype: the given function won't be applied if this type is specified and the given collections is of
the :attr:`wrong_type` even if it is of type :attr`dtype`
- **kwargs: keyword arguments (will be forwarded to calls of ``function``)
+ **kwargs: keyword arguments (will be forwarded to call of ``function``)
Returns:
the resulting collection
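A brief usage sketch of ``apply_to_collection`` with forwarded keyword arguments (example collection assumed):

    import torch
    from torchmetrics.utilities.data import apply_to_collection

    batch = {"preds": torch.tensor([1.0, 2.0]), "target": torch.tensor([3.0, 4.0])}
    # ``function`` is applied to every element of type ``dtype``;
    # extra kwargs are forwarded to each call of ``function``
    doubled = apply_to_collection(batch, dtype=torch.Tensor, function=torch.mul, other=torch.tensor(2.0))
    # doubled == {"preds": tensor([2., 4.]), "target": tensor([6., 8.])}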
10 changes: 5 additions & 5 deletions torchmetrics/utilities/distributed.py
@@ -19,11 +19,11 @@
from typing_extensions import Literal


- def reduce(to_reduce: Tensor, reduction: Literal["elementwise_mean", "sum", "none", None]) -> Tensor:
+ def reduce(x: Tensor, reduction: Literal["elementwise_mean", "sum", "none", None]) -> Tensor:
"""Reduces a given tensor by a given reduction method.
Args:
- to_reduce: the tensor, which shall be reduced
+ x: the tensor, which shall be reduced
reduction: a string specifying the reduction method ('elementwise_mean', 'none', 'sum')
Return:
@@ -33,11 +33,11 @@ def reduce(to_reduce: Tensor, reduction: Literal["elementwise_mean", "sum", "non
ValueError if an invalid reduction parameter was given
"""
if reduction == "elementwise_mean":
- return torch.mean(to_reduce)
+ return torch.mean(x)
if reduction == "none" or reduction is None:
- return to_reduce
+ return x
if reduction == "sum":
- return torch.sum(to_reduce)
+ return torch.sum(x)
raise ValueError("Reduction parameter unknown.")


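The rename from ``to_reduce`` to ``x`` is purely cosmetic; behaviour is unchanged, as this quick usage sketch shows:

    import torch
    from torchmetrics.utilities.distributed import reduce

    x = torch.tensor([1.0, 2.0, 3.0])
    reduce(x, reduction="elementwise_mean")  # tensor(2.)
    reduce(x, reduction="sum")               # tensor(6.)
    reduce(x, reduction="none")              # returns x unchanged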
4 changes: 4 additions & 0 deletions torchmetrics/utilities/imports.py
@@ -14,6 +14,7 @@
"""Import utilities."""
import operator
from collections import OrderedDict # noqa: F401
+ from functools import lru_cache
from importlib import import_module
from importlib.util import find_spec
from typing import Callable, Optional
@@ -22,6 +23,7 @@
from pkg_resources import DistributionNotFound, get_distribution


+ @lru_cache()
def _package_available(package_name: str) -> bool:
"""Check if a package is available in your environment.
@@ -40,6 +42,7 @@ def _package_available(package_name: str) -> bool:
return False


+ @lru_cache()
def _module_available(module_path: str) -> bool:
"""Check if a module path is available in your environment.
@@ -64,6 +67,7 @@ def _module_available(module_path: str) -> bool:
return True


+ @lru_cache()
def _compare_version(package: str, op: Callable, version: str) -> Optional[bool]:
"""Compare package version with some requirements.
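The new ``@lru_cache()`` decorators mean each availability check runs at most once per argument; later calls return the memoized result. A simplified sketch of the idea (the body below is illustrative, not the exact library code):

    from functools import lru_cache
    from importlib.util import find_spec

    @lru_cache()
    def _package_available(package_name: str) -> bool:
        try:
            return find_spec(package_name) is not None  # expensive lookup happens once
        except ModuleNotFoundError:
            return False

    _package_available("torch")  # performs the actual lookup
    _package_available("torch")  # served from the cache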
