diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4d80b98376e..fcf702137b3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,6 +18,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)

 - Added pre-gather reduction in the case of `dist_reduce_fx="cat"` to reduce communication cost ([#217](https://github.com/PyTorchLightning/metrics/pull/217))

+- Added support for unnormalized scores (e.g. logits) in `Accuracy`, `Precision`, `Recall`, `FBeta`, `F1`, `StatScores`, `HammingDistance`, `ConfusionMatrix` metrics ([#200](https://github.com/PyTorchLightning/metrics/pull/200))
+
 ### Changed

diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst
index 9a667b6c315..9889d7a6f46 100644
--- a/docs/source/references/modules.rst
+++ b/docs/source/references/modules.rst
@@ -34,10 +34,10 @@ into these categories (``N`` stands for the batch size and ``C`` for number of classes):
     "Binary", "(N,)", "``float``", "(N,)", "``binary``\*"
     "Multi-class", "(N,)", "``int``", "(N,)", "``int``"
-    "Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``"
+    "Multi-class with logits or probabilities", "(N, C)", "``float``", "(N,)", "``int``"
     "Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*"
     "Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``"
-    "Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``"
+    "Multi-dimensional multi-class with logits or probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``"

 .. note::
     All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so
diff --git a/tests/classification/inputs.py b/tests/classification/inputs.py
index c6ce83a4069..676083fb7b4 100644
--- a/tests/classification/inputs.py
+++ b/tests/classification/inputs.py
@@ -28,6 +28,10 @@
     target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))
 )

+_input_binary_logits = Input(
+    preds=torch.randn(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))
+)
+
 _input_multilabel_prob = Input(
     preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES),
     target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES))
@@ -38,11 +42,17 @@
     target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM))
 )

+_input_multilabel_logits = Input(
+    preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES),
+    target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES))
+)
+
 _input_multilabel = Input(
     preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)),
     target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES))
 )

+
 _input_multilabel_multidim = Input(
     preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)),
     target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM))
 )
@@ -54,13 +64,17 @@
 _input_multilabel_no_match = Input(preds=__temp_preds, target=__temp_target)

-__mc_prob_preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)
-__mc_prob_preds = __mc_prob_preds / __mc_prob_preds.sum(dim=2, keepdim=True)
+__mc_prob_logits = torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)
+__mc_prob_preds = __mc_prob_logits.abs() / __mc_prob_logits.abs().sum(dim=2, keepdim=True)

 _input_multiclass_prob = Input(
     preds=__mc_prob_preds, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
 )

+_input_multiclass_logits = Input(
+    preds=__mc_prob_logits, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
+)
+
 _input_multiclass = Input(
     preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)),
     target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
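The logits fixtures above draw preds from torch.randn, i.e. unnormalized real-valued scores. A quick sanity sketch (illustrative tensors, not part of the test suite) of why the docstrings below describe 0.5 as the probability-space default: thresholding logits at 0.0 is equivalent to thresholding sigmoid probabilities at 0.5.

    import torch

    logits = torch.randn(4)        # unnormalized scores, any real value
    probs = torch.sigmoid(logits)  # monotone map into [0, 1]

    # sigmoid(0.0) == 0.5, so the two binarizations agree element-wise
    assert ((logits >= 0.0) == (probs >= 0.5)).all()
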
diff --git a/tests/classification/test_accuracy.py b/tests/classification/test_accuracy.py
index 00e4fa32a35..c0569caeb00 100644
--- a/tests/classification/test_accuracy.py
+++ b/tests/classification/test_accuracy.py
@@ -18,12 +18,14 @@
 from sklearn.metrics import accuracy_score as sk_accuracy
 from torch import tensor

-from tests.classification.inputs import _input_binary, _input_binary_prob
+from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob
 from tests.classification.inputs import _input_multiclass as _input_mcls
+from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
 from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
 from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
 from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
 from tests.classification.inputs import _input_multilabel as _input_mlb
+from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
 from tests.classification.inputs import _input_multilabel_multidim as _input_mlmd
 from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
 from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
@@ -55,13 +57,16 @@ def _sk_accuracy(preds, target, subset_accuracy):
 @pytest.mark.parametrize(
     "preds, target, subset_accuracy",
     [
+        (_input_binary_logits.preds, _input_binary_logits.target, False),
         (_input_binary_prob.preds, _input_binary_prob.target, False),
         (_input_binary.preds, _input_binary.target, False),
         (_input_mlb_prob.preds, _input_mlb_prob.target, True),
+        (_input_mlb_logits.preds, _input_mlb_logits.target, False),
         (_input_mlb_prob.preds, _input_mlb_prob.target, False),
         (_input_mlb.preds, _input_mlb.target, True),
         (_input_mlb.preds, _input_mlb.target, False),
         (_input_mcls_prob.preds, _input_mcls_prob.target, False),
+        (_input_mcls_logits.preds, _input_mcls_logits.target, False),
         (_input_mcls.preds, _input_mcls.target, False),
         (_input_mdmc_prob.preds, _input_mdmc_prob.target, False),
         (_input_mdmc_prob.preds, _input_mdmc_prob.target, True),
diff --git a/tests/classification/test_confusion_matrix.py b/tests/classification/test_confusion_matrix.py
index 2f43e5ac3e2..44f599eedcd 100644
--- a/tests/classification/test_confusion_matrix.py
+++ b/tests/classification/test_confusion_matrix.py
@@ -19,12 +19,14 @@
 from sklearn.metrics import confusion_matrix as sk_confusion_matrix
 from sklearn.metrics import multilabel_confusion_matrix as sk_multilabel_confusion_matrix

-from tests.classification.inputs import _input_binary, _input_binary_prob
+from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob
 from tests.classification.inputs import _input_multiclass as _input_mcls
+from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
 from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
 from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
 from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
 from tests.classification.inputs import _input_multilabel as _input_mlb
+from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
 from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
 from tests.helpers import seed_all
 from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
@@ -112,10 +114,13 @@ def _sk_cm_multidim_multiclass(preds, target, normalize=None):
 @pytest.mark.parametrize(
     "preds, target, sk_metric, num_classes, multilabel",
     [(_input_binary_prob.preds, _input_binary_prob.target, _sk_cm_binary_prob, 2, False),
+     (_input_binary_logits.preds, _input_binary_logits.target, _sk_cm_binary_prob, 2, False),
     (_input_binary.preds, _input_binary.target, _sk_cm_binary, 2, False),
     (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cm_multilabel_prob, NUM_CLASSES, True),
+    (_input_mlb_logits.preds, _input_mlb_logits.target, _sk_cm_multilabel_prob, NUM_CLASSES, True),
     (_input_mlb.preds, _input_mlb.target, _sk_cm_multilabel, NUM_CLASSES, True),
     (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cm_multiclass_prob, NUM_CLASSES, False),
+    (_input_mcls_logits.preds, _input_mcls_logits.target, _sk_cm_multiclass_prob, NUM_CLASSES, False),
     (_input_mcls.preds, _input_mcls.target, _sk_cm_multiclass, NUM_CLASSES, False),
     (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cm_multidim_multiclass_prob, NUM_CLASSES, False),
     (_input_mdmc.preds, _input_mdmc.target, _sk_cm_multidim_multiclass, NUM_CLASSES, False)]
diff --git a/tests/classification/test_f_beta.py b/tests/classification/test_f_beta.py
index b00758c32ba..693c445dd4b 100644
--- a/tests/classification/test_f_beta.py
+++ b/tests/classification/test_f_beta.py
@@ -20,12 +20,14 @@
 from sklearn.metrics import f1_score, fbeta_score
 from torch import Tensor

-from tests.classification.inputs import _input_binary, _input_binary_prob
+from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob
 from tests.classification.inputs import _input_multiclass as _input_mcls
+from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
 from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
 from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
 from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
 from tests.classification.inputs import _input_multilabel as _input_mlb
+from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
 from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
 from tests.helpers import seed_all
 from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
@@ -176,10 +178,13 @@ def test_no_support(metric_class, metric_fn):
 @pytest.mark.parametrize(
     "preds, target, num_classes, multiclass, mdmc_average, sk_wrapper",
     [
+        (_input_binary_logits.preds, _input_binary_logits.target, 1, None, None, _sk_fbeta_f1),
         (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_fbeta_f1),
         (_input_binary.preds, _input_binary.target, 1, False, None, _sk_fbeta_f1),
+        (_input_mlb_logits.preds, _input_mlb_logits.target, NUM_CLASSES, None, None, _sk_fbeta_f1),
        (_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_fbeta_f1),
         (_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_fbeta_f1),
+        (_input_mcls_logits.preds, _input_mcls_logits.target, NUM_CLASSES, None, None, _sk_fbeta_f1),
         (_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_fbeta_f1),
         (_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_fbeta_f1),
         (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_fbeta_f1_multidim_multiclass),
diff --git a/tests/classification/test_hamming_distance.py b/tests/classification/test_hamming_distance.py
index 5b8e7000716..8f629e7d6f3 100644
--- a/tests/classification/test_hamming_distance.py
+++ b/tests/classification/test_hamming_distance.py
@@ -14,12 +14,14 @@
 import pytest
 from sklearn.metrics import hamming_loss as sk_hamming_loss

-from tests.classification.inputs import _input_binary, _input_binary_prob
+from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob
 from tests.classification.inputs import _input_multiclass as _input_mcls
+from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
 from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
 from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
 from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
 from tests.classification.inputs import _input_multilabel as _input_mlb
+from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
 from tests.classification.inputs import _input_multilabel_multidim as _input_mlmd
 from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
 from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
@@ -43,10 +45,13 @@ def _sk_hamming_loss(preds, target):
 @pytest.mark.parametrize(
     "preds, target",
     [
+        (_input_binary_logits.preds, _input_binary_logits.target),
         (_input_binary_prob.preds, _input_binary_prob.target),
         (_input_binary.preds, _input_binary.target),
+        (_input_mlb_logits.preds, _input_mlb_logits.target),
         (_input_mlb_prob.preds, _input_mlb_prob.target),
         (_input_mlb.preds, _input_mlb.target),
+        (_input_mcls_logits.preds, _input_mcls_logits.target),
         (_input_mcls_prob.preds, _input_mcls_prob.target),
         (_input_mcls.preds, _input_mcls.target),
         (_input_mdmc_prob.preds, _input_mdmc_prob.target),
diff --git a/tests/classification/test_inputs.py b/tests/classification/test_inputs.py
index a72bea8a496..7b92e462a2a 100644
--- a/tests/classification/test_inputs.py
+++ b/tests/classification/test_inputs.py
@@ -218,13 +218,6 @@ def test_threshold():
 ########################################################################

-@pytest.mark.parametrize("threshold", [-0.5, 0.0, 1.0, 1.5])
-def test_incorrect_threshold(threshold):
-    preds, target = rand(size=(7, )), randint(high=2, size=(7, ))
-    with pytest.raises(ValueError):
-        _input_format_classification(preds, target, threshold=threshold)
-
-
 @pytest.mark.parametrize(
     "preds, target, num_classes, multiclass",
     [
@@ -234,8 +227,6 @@ def test_incorrect_threshold(threshold):
         (randint(high=2, size=(7, )), -randint(high=2, size=(7, )), None, None),
         # Preds negative integers
         (-randint(high=2, size=(7, )), randint(high=2, size=(7, )), None, None),
-        # Negative probabilities
-        (-rand(size=(7, )), randint(high=2, size=(7, )), None, None),
         # multiclass=False and target > 1
         (rand(size=(7, )), randint(low=2, high=4, size=(7, )), None, False),
         # multiclass=False and preds integers with > 1
@@ -254,8 +245,6 @@ def test_incorrect_threshold(threshold):
         (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), None, None),
         # multiclass=False, with C dimension > 2
         (_mc_prob.preds[0], randint(high=2, size=(BATCH_SIZE, )), None, False),
-        # Probs of multiclass preds do not sum up to 1
-        (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), None, None),
         # Max target larger or equal to C dimension
         (_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE, )), None, None),
         # C dimension not equal to num_classes
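The deleted test cases above are the point of this PR: negative float preds and thresholds outside (0, 1) are no longer rejected. A hedged sketch of a call that used to raise ValueError and should now pass on a build that includes this change (tensors are illustrative):

    import torch
    from torchmetrics.functional import accuracy

    preds = torch.randn(7)                    # float preds with negative values, i.e. logits
    target = torch.randint(high=2, size=(7,))

    # a threshold outside (0, 1) is no longer rejected; 0.0 is the natural
    # decision boundary for binary logits
    acc = accuracy(preds, target, threshold=0.0)
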
diff --git a/tests/classification/test_precision_recall.py b/tests/classification/test_precision_recall.py
index 0b65ba8fda8..d627ba67b34 100644
--- a/tests/classification/test_precision_recall.py
+++ b/tests/classification/test_precision_recall.py
@@ -20,12 +20,14 @@
 from sklearn.metrics import precision_score, recall_score
 from torch import Tensor, tensor

-from tests.classification.inputs import _input_binary, _input_binary_prob
+from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob
 from tests.classification.inputs import _input_multiclass as _input_mcls
+from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
 from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
 from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
 from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
 from tests.classification.inputs import _input_multilabel as _input_mlb
+from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
 from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
 from tests.helpers import seed_all
 from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
@@ -175,10 +177,13 @@ def test_no_support(metric_class, metric_fn):
 @pytest.mark.parametrize(
     "preds, target, num_classes, multiclass, mdmc_average, sk_wrapper",
     [
+        (_input_binary_logits.preds, _input_binary_logits.target, 1, None, None, _sk_prec_recall),
         (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_prec_recall),
         (_input_binary.preds, _input_binary.target, 1, False, None, _sk_prec_recall),
+        (_input_mlb_logits.preds, _input_mlb_logits.target, NUM_CLASSES, None, None, _sk_prec_recall),
         (_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_prec_recall),
         (_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_prec_recall),
+        (_input_mcls_logits.preds, _input_mcls_logits.target, NUM_CLASSES, None, None, _sk_prec_recall),
         (_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_prec_recall),
         (_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_prec_recall),
         (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_prec_recall_multidim_multiclass),
diff --git a/tests/classification/test_stat_scores.py b/tests/classification/test_stat_scores.py
index 6436dfc4c23..30994a6d499 100644
--- a/tests/classification/test_stat_scores.py
+++ b/tests/classification/test_stat_scores.py
@@ -20,11 +20,13 @@
 from sklearn.metrics import multilabel_confusion_matrix
 from torch import Tensor, tensor

-from tests.classification.inputs import _input_binary, _input_binary_prob, _input_multiclass
+from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob, _input_multiclass
+from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
 from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
 from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
 from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
 from tests.classification.inputs import _input_multilabel as _input_mcls
+from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
 from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
 from tests.helpers import seed_all
 from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
@@ -136,12 +138,15 @@ def test_wrong_threshold():
 @pytest.mark.parametrize(
     "preds, target, sk_fn, mdmc_reduce, num_classes, multiclass, top_k",
     [
+        (_input_binary_logits.preds, _input_binary_logits.target, _sk_stat_scores, None, 1, None, None),
         (_input_binary_prob.preds, _input_binary_prob.target, _sk_stat_scores, None, 1, None, None),
         (_input_binary.preds, _input_binary.target, _sk_stat_scores, None, 1, False, None),
+        (_input_mlb_logits.preds, _input_mlb_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
         (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
         (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
         (_input_mcls.preds, _input_mcls.target, _sk_stat_scores, None, NUM_CLASSES, False, None),
         (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
+        (_input_mcls_logits.preds, _input_mcls_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
         (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
         (_input_multiclass.preds, _input_multiclass.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
         (_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None, None),
diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py
index becaf9c2acf..1c8659ed0e7 100644
--- a/torchmetrics/classification/accuracy.py
+++ b/torchmetrics/classification/accuracy.py
@@ -37,9 +37,9 @@ class Accuracy(StatScores):

     Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.

-    For multi-class and multi-dimensional multi-class data with probability predictions, the
+    For multi-class and multi-dimensional multi-class data with probability or logit predictions, the
     parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the
-    top-K highest probability items are considered to find the correct label.
+    top-K highest probability or logit score items are considered to find the correct label.

     For multi-label and multi-dimensional multi-class inputs, this metric computes the "global"
     accuracy by default, which counts all labels or sub-samples separately. This can be
@@ -52,8 +52,8 @@ class Accuracy(StatScores):
         num_classes:
             Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
         threshold:
-            Threshold probability value for transforming probability predictions to binary
-            (0,1) predictions, in the case of binary or multi-label inputs.
+            Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
+            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
         average:
             Defines the reduction that is applied. Should be one of the following:
@@ -94,8 +94,8 @@ class Accuracy(StatScores):
             or ``'none'``, the score for the ignored class will be returned as ``nan``.

         top_k:
-            Number of highest probability predictions considered to find the correct label, relevant
-            only for (multi-dimensional) multi-class inputs with probability predictions. The
+            Number of highest probability or logit score predictions considered to find the correct label,
+            relevant only for (multi-dimensional) multi-class inputs. The
             default value (``None``) will be interpreted as 1 for these inputs.

             Should be left at default (``None``) for all other types of inputs.
@@ -218,7 +218,7 @@ def update(self, preds: Tensor, target: Tensor):
         on input types.

         Args:
-            preds: Predictions from model (probabilities, or labels)
+            preds: Predictions from model (logits, probabilities, or labels)
             target: Ground truth labels
         """
         """ returns the mode of the data (binary, multi label, multi class, multi-dim multi class) """
diff --git a/torchmetrics/classification/confusion_matrix.py b/torchmetrics/classification/confusion_matrix.py
index a22ba766d5c..5dc2f03abdb 100644
--- a/torchmetrics/classification/confusion_matrix.py
+++ b/torchmetrics/classification/confusion_matrix.py
@@ -24,8 +24,9 @@ class ConfusionMatrix(Metric):
     """
     Computes the `confusion matrix `_. Works with binary,
-    multiclass, and multilabel data. Accepts probabilities from a model output or integer class values in prediction.
-    Works with multi-dimensional preds and target, but it should be noted that additional dimensions will be flattened.
+    multiclass, and multilabel data. Accepts probabilities or logits from a model output or integer class
+    values in prediction. Works with multi-dimensional preds and target, but it should be noted that
+    additional dimensions will be flattened.

     Forward accepts
@@ -33,7 +34,7 @@ class ConfusionMatrix(Metric):
     - ``target`` (long tensor): ``(N, ...)``

     If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument
-    to convert into integer labels. This is the case for binary and multi-label probabilities.
+    to convert into integer labels. This is the case for binary and multi-label probabilities or logits.

     If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
@@ -51,7 +52,9 @@ class ConfusionMatrix(Metric):
         - ``'all'``: normalization over the whole matrix

         threshold:
-            Threshold value for binary or multi-label probabilites. default: 0.5
+            Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
+            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
+
         multilabel:
             determines if data is multilabel or not.
         compute_on_step:
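For (multi-dimensional) multi-class preds only the ordering of the scores matters: argmax and top-k are invariant under softmax, so logits and probabilities give identical results. A minimal sketch against the module interface documented above (shapes are made up for illustration):

    import torch
    from torchmetrics import Accuracy

    logits = torch.randn(10, 5)                 # (N, C) unnormalized class scores
    target = torch.randint(high=5, size=(10,))

    top1 = Accuracy()          # argmax(logits) == argmax(softmax(logits))
    top2 = Accuracy(top_k=2)   # the same invariance holds for any top-k
    top1(logits, target), top2(logits, target)
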
diff --git a/torchmetrics/classification/f_beta.py b/torchmetrics/classification/f_beta.py
index a79102a45d6..c07b110d5f0 100644
--- a/torchmetrics/classification/f_beta.py
+++ b/torchmetrics/classification/f_beta.py
@@ -30,7 +30,7 @@ class FBeta(StatScores):
         {(\beta^2 * \text{precision}) + \text{recall}}

     Where :math:`\beta` is some positive real factor. Works with binary, multiclass, and multilabel data.
-    Accepts probabilities from a model output or integer class values in prediction.
+    Accepts logit scores or probabilities from a model output or integer class values in prediction.
     Works with multi-dimensional preds and target.

     Forward accepts
@@ -39,7 +39,7 @@ class FBeta(StatScores):
     - ``target`` (long tensor): ``(N, ...)``

     If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument
-    to convert into integer labels. This is the case for binary and multi-label probabilities.
+    to convert into integer labels. This is the case for binary and multi-label logits and probabilities.

     If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
@@ -49,8 +49,8 @@ class FBeta(StatScores):
         beta: Beta coefficient in the F measure.
         threshold:
-            Threshold probability value for transforming probability predictions to binary
-            (0,1) predictions, in the case of binary or multi-label inputs.
+            Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
+            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
         average:
             Defines the reduction that is applied. Should be one of the following:
@@ -91,12 +91,12 @@ class FBeta(StatScores):
             or ``'none'``, the score for the ignored class will be returned as ``nan``.

         top_k:
-            Number of highest probability entries for each sample to convert to 1s - relevant
-            only for inputs with probability predictions. If this parameter is set for multi-label
-            inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs,
-            this parameter defaults to 1.
+            Number of highest probability or logit score predictions considered to find the correct label,
+            relevant only for (multi-dimensional) multi-class inputs. The
+            default value (``None``) will be interpreted as 1 for these inputs.
+
+            Should be left at default (``None``) for all other types of inputs.

-            Should be left unset (``None``) for inputs with label predictions.
         multiclass:
             Used only in certain special cases, where you want to treat inputs as a different type
             than what they appear to be. See the parameter's
@@ -181,12 +181,10 @@ def compute(self) -> Tensor:

 class F1(FBeta):
     """
-    Computes F1 metric. F1 metrics correspond to a harmonic mean of the
-    precision and recall scores.
+    Computes F1 metric. F1 metrics correspond to a harmonic mean of the precision and recall scores.

-    Works with binary, multiclass, and multilabel data.
-    Accepts logits from a model output or integer class values in prediction.
-    Works with multi-dimensional preds and target.
+    Works with binary, multiclass, and multilabel data. Accepts logits or probabilities from a model
+    output or integer class values in prediction. Works with multi-dimensional preds and target.

     Forward accepts
@@ -202,8 +200,8 @@ class F1(FBeta):
         num_classes:
             Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
         threshold:
-            Threshold probability value for transforming probability predictions to binary
-            (0,1) predictions, in the case of binary or multi-label inputs.
+            Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
+            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
         average:
             Defines the reduction that is applied. Should be one of the following:
@@ -244,12 +242,11 @@ class F1(FBeta):
             or ``'none'``, the score for the ignored class will be returned as ``nan``.

         top_k:
-            Number of highest probability entries for each sample to convert to 1s - relevant
-            only for inputs with probability predictions. If this parameter is set for multi-label
-            inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs,
-            this parameter defaults to 1.
+            Number of highest probability or logit score predictions considered to find the correct label,
+            relevant only for (multi-dimensional) multi-class inputs. The
+            default value (``None``) will be interpreted as 1 for these inputs.

-            Should be left unset (``None``) for inputs with label predictions.
+            Should be left at default (``None``) for all other types of inputs.
         multiclass:
             Used only in certain special cases, where you want to treat inputs as a different type
             than what they appear to be. See the parameter's
diff --git a/torchmetrics/classification/hamming_distance.py b/torchmetrics/classification/hamming_distance.py
index 497db607ad6..90fc8302edb 100644
--- a/torchmetrics/classification/hamming_distance.py
+++ b/torchmetrics/classification/hamming_distance.py
@@ -40,8 +40,8 @@ class HammingDistance(Metric):
     Args:
         threshold:
-            Threshold probability value for transforming probability predictions to binary
-            (0 or 1) predictions, in the case of binary or multi-label inputs.
+            Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
+            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
         compute_on_step:
             Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
         dist_sync_on_step:
@@ -96,7 +96,7 @@ def update(self, preds: Tensor, target: Tensor):
         on input types.

         Args:
-            preds: Predictions from model (probabilities, or labels)
+            preds: Predictions from model (probabilities, logits or labels)
             target: Ground truth labels
         """
         correct, total = _hamming_distance_update(preds, target, self.threshold)
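For binary and multi-label inputs the threshold is applied to the raw scores, so logits pair naturally with a logit-space threshold. A hedged sketch against the HammingDistance module just documented (shapes illustrative):

    import torch
    from torchmetrics import HammingDistance

    logits = torch.randn(10, 4)                   # (N, num_labels) multi-label logits
    target = torch.randint(high=2, size=(10, 4))

    hamming = HammingDistance(threshold=0.0)      # 0.0 on logits ~ 0.5 on probabilities
    hamming(logits, target)
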
diff --git a/torchmetrics/classification/precision_recall.py b/torchmetrics/classification/precision_recall.py
index f554b649be8..a8a43e25393 100644
--- a/torchmetrics/classification/precision_recall.py
+++ b/torchmetrics/classification/precision_recall.py
@@ -39,8 +39,8 @@ class Precision(StatScores):
         num_classes:
             Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
         threshold:
-            Threshold probability value for transforming probability predictions to binary
-            (0,1) predictions, in the case of binary or multi-label inputs.
+            Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
+            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
         average:
             Defines the reduction that is applied. Should be one of the following:
@@ -81,12 +81,12 @@ class Precision(StatScores):
             or ``'none'``, the score for the ignored class will be returned as ``nan``.

         top_k:
-            Number of highest probability entries for each sample to convert to 1s - relevant
-            only for inputs with probability predictions. If this parameter is set for multi-label
-            inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs,
-            this parameter defaults to 1.
+            Number of highest probability or logit score predictions considered to find the correct label,
+            relevant only for (multi-dimensional) multi-class inputs. The
+            default value (``None``) will be interpreted as 1 for these inputs.
+
+            Should be left at default (``None``) for all other types of inputs.

-            Should be left unset (``None``) for inputs with label predictions.
         multiclass:
             Used only in certain special cases, where you want to treat inputs as a different type
             than what they appear to be. See the parameter's
@@ -200,8 +200,8 @@ class Recall(StatScores):
         num_classes:
             Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
         threshold:
-            Threshold probability value for transforming probability predictions to binary
-            (0,1) predictions, in the case of binary or multi-label inputs.
+            Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
+            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
         average:
             Defines the reduction that is applied. Should be one of the following:
@@ -242,12 +242,11 @@ class Recall(StatScores):
             or ``'none'``, the score for the ignored class will be returned as ``nan``.

         top_k:
-            Number of highest probability entries for each sample to convert to 1s - relevant
-            only for inputs with probability predictions. If this parameter is set for multi-label
-            inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs,
-            this parameter defaults to 1.
+            Number of highest probability or logit score predictions considered to find the correct label,
+            relevant only for (multi-dimensional) multi-class inputs. The
+            default value (``None``) will be interpreted as 1 for these inputs.

-            Should be left unset (``None``) for inputs with label predictions.
+            Should be left at default (``None``) for all other types of inputs.

         multiclass:
             Used only in certain special cases, where you want to treat inputs as a different type
diff --git a/torchmetrics/classification/stat_scores.py b/torchmetrics/classification/stat_scores.py
index 6206fe13e4e..1315c76d814 100644
--- a/torchmetrics/classification/stat_scores.py
+++ b/torchmetrics/classification/stat_scores.py
@@ -35,16 +35,15 @@ class StatScores(Metric):
     Args:
         threshold:
-            Threshold probability value for transforming probability predictions to binary
-            (0 or 1) predictions, in the case of binary or multi-label inputs.
+            Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
+            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.

         top_k:
-            Number of highest probability entries for each sample to convert to 1s - relevant
-            only for inputs with probability predictions. If this parameter is set for multi-label
-            inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs,
-            this parameter defaults to 1.
+            Number of highest probability or logit score predictions considered to find the correct label,
+            relevant only for (multi-dimensional) multi-class inputs. The
+            default value (``None``) will be interpreted as 1 for these inputs.

-            Should be left unset (``None``) for inputs with label predictions.
+            Should be left at default (``None``) for all other types of inputs.

         reduce:
             Defines the reduction that is applied. Should be one of the following:
@@ -200,7 +199,7 @@ def update(self, preds: Tensor, target: Tensor):
         on input types.

         Args:
-            preds: Predictions from model (probabilities or labels)
+            preds: Predictions from model (probabilities, logits or labels)
             target: Ground truth values
         """
diff --git a/torchmetrics/functional/classification/accuracy.py b/torchmetrics/functional/classification/accuracy.py
index 4ce583023c2..73002ab2112 100644
--- a/torchmetrics/functional/classification/accuracy.py
+++ b/torchmetrics/functional/classification/accuracy.py
@@ -139,9 +139,9 @@ def accuracy(

     Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.
-    For multi-class and multi-dimensional multi-class data with probability predictions, the
+    For multi-class and multi-dimensional multi-class data with probability or logit predictions, the
     parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the
-    top-K highest probability items are considered to find the correct label.
+    top-K highest probability or logit score items are considered to find the correct label.

     For multi-label and multi-dimensional multi-class inputs, this metric computes the "global"
     accuracy by default, which counts all labels or sub-samples separately. This can be
@@ -151,7 +151,7 @@ def accuracy(
     Accepts all input types listed in :ref:`references/modules:input types`.

     Args:
-        preds: Predictions from model (probabilities, or labels)
+        preds: Predictions from model (probabilities, logits or labels)
         target: Ground truth labels
         average:
             Defines the reduction that is applied. Should be one of the following:
@@ -190,11 +190,11 @@ def accuracy(
             Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
         threshold:
-            Threshold probability value for transforming probability predictions to binary
-            (0,1) predictions, in the case of binary or multi-label inputs.
+            Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
+            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
         top_k:
-            Number of highest probability predictions considered to find the correct label, relevant
-            only for (multi-dimensional) multi-class inputs with probability predictions. The
+            Number of highest probability or logit score predictions considered to find the correct label,
+            relevant only for (multi-dimensional) multi-class inputs. The
             default value (``None``) will be interpreted as 1 for these inputs.

             Should be left at default (``None``) for all other types of inputs.
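A small worked example of the binary case described above (numbers chosen by hand for illustration):

    import torch
    from torchmetrics.functional import accuracy

    logits = torch.tensor([1.7, -0.3, 0.9, -2.1])  # binary logits, shape (N,)
    target = torch.tensor([1, 0, 1, 1])

    # thresholding at 0.0 yields predictions [1, 0, 1, 0] -> 3 of 4 correct
    accuracy(logits, target, threshold=0.0)        # tensor(0.7500)
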
diff --git a/torchmetrics/functional/classification/confusion_matrix.py b/torchmetrics/functional/classification/confusion_matrix.py
index e37b0a595a1..f0d45c26b38 100644
--- a/torchmetrics/functional/classification/confusion_matrix.py
+++ b/torchmetrics/functional/classification/confusion_matrix.py
@@ -76,11 +76,12 @@ def confusion_matrix(
    """
    Computes the `confusion matrix `_. Works with binary,
-    multiclass, and multilabel data. Accepts probabilities from a model output or integer class values in prediction.
-    Works with multi-dimensional preds and target, but it should be noted that additional dimensions will be flattened.
+    multiclass, and multilabel data. Accepts probabilities or logits from a model output or integer class
+    values in prediction. Works with multi-dimensional preds and target, but it should be noted that
+    additional dimensions will be flattened.

    If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument
-    to convert into integer labels. This is the case for binary and multi-label probabilities.
+    to convert into integer labels. This is the case for binary and multi-label probabilities or logits.

    If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
@@ -90,7 +91,7 @@ def confusion_matrix(

    Args:
        preds: (float or long tensor), Either a ``(N, ...)`` tensor with labels or
-            ``(N, C, ...)`` where C is the number of classes, tensor with labels/probabilities
+            ``(N, C, ...)`` where C is the number of classes, tensor with labels/logits/probabilities
        target: ``target`` (long tensor), tensor with shape ``(N, ...)`` with ground true labels
        num_classes: Number of classes in the dataset.
        normalize: Normalization mode for confusion matrix. Choose from
@@ -101,7 +102,9 @@ def confusion_matrix(
            - ``'all'``: normalization over the whole matrix

        threshold:
-            Threshold value for binary or multi-label probabilities. default: 0.5
+            Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
+            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
+
        multilabel:
            determines if data is multilabel or not.
diff --git a/torchmetrics/functional/classification/f_beta.py b/torchmetrics/functional/classification/f_beta.py
index c37806d0cdc..0cd843bca30 100644
--- a/torchmetrics/functional/classification/f_beta.py
+++ b/torchmetrics/functional/classification/f_beta.py
@@ -92,11 +92,11 @@ def fbeta(
        {(\beta^2 * \text{precision}) + \text{recall}}

    Works with binary, multiclass, and multilabel data.
-    Accepts probabilities from a model output or integer class values in prediction.
+    Accepts probabilities or logits from a model output or integer class values in prediction.
    Works with multi-dimensional preds and target.

    If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument
-    to convert into integer labels. This is the case for binary and multi-label probabilities.
+    to convert into integer labels. This is the case for binary and multi-label logits or probabilities.

    If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.

@@ -105,7 +105,7 @@ def fbeta(
    multi-dimensional multi-class case. Accepts all inputs listed in :ref:`references/modules:input types`.

    Args:
-        preds: Predictions from model (probabilities or labels)
+        preds: Predictions from model (probabilities, logits or labels)
        target: Ground truth values
        average:
            Defines the reduction that is applied. Should be one of the following:
@@ -146,13 +146,14 @@ def fbeta(
        num_classes:
            Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
        threshold:
-            Threshold probability value for transforming probability predictions to binary
-            (0,1) predictions, in the case of binary or multi-label inputs.
+            Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
+            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
        top_k:
-            Number of highest probability entries for each sample to convert to 1s - relevant
-            only for inputs with probability predictions. If this parameter is set for multi-label
-            inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs,
-            this parameter defaults to 1. Should be left unset (``None``) for inputs with label predictions.
+            Number of highest probability or logit score predictions considered to find the correct label,
+            relevant only for (multi-dimensional) multi-class inputs. The
+            default value (``None``) will be interpreted as 1 for these inputs.
+
+            Should be left at default (``None``) for all other types of inputs.
        multiclass:
            Used only in certain special cases, where you want to treat inputs as a different type
            than what they appear to be. See the parameter's
@@ -227,11 +228,11 @@ def f1(
    precision and recall scores.

    Works with binary, multiclass, and multilabel data.
-    Accepts probabilities from a model output or integer class values in prediction.
+    Accepts probabilities or logits from a model output or integer class values in prediction.
    Works with multi-dimensional preds and target.

    If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument
-    to convert into integer labels. This is the case for binary and multi-label probabilities.
+    to convert into integer labels. This is the case for binary and multi-label probabilities or logits.

    If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.

@@ -240,7 +241,7 @@ def f1(
    multi-dimensional multi-class case. Accepts all inputs listed in :ref:`references/modules:input types`.

    Args:
-        preds: Predictions from model (probabilities or labels)
+        preds: Predictions from model (probabilities, logits or labels)
        target: Ground truth values
        average:
            Defines the reduction that is applied. Should be one of the following:
@@ -285,15 +286,15 @@ def f1(
        num_classes:
            Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
        threshold:
-            Threshold probability value for transforming probability predictions to binary
-            (0,1) predictions, in the case of binary or multi-label inputs.
+            Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
+            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
        top_k:
-            Number of highest probability entries for each sample to convert to 1s - relevant
-            only for inputs with probability predictions. If this parameter is set for multi-label
-            inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs,
-            this parameter defaults to 1.
+            Number of highest probability or logit score predictions considered to find the correct label,
+            relevant only for (multi-dimensional) multi-class inputs. The
+            default value (``None``) will be interpreted as 1 for these inputs.
+
+            Should be left at default (``None``) for all other types of inputs.

-            Should be left unset (``None``) for inputs with label predictions.
        multiclass:
            Used only in certain special cases, where you want to treat inputs as a different type
            than what they appear to be. See the parameter's
diff --git a/torchmetrics/functional/classification/hamming_distance.py b/torchmetrics/functional/classification/hamming_distance.py
index f6ab51f0a6b..d38e8ce1939 100644
--- a/torchmetrics/functional/classification/hamming_distance.py
+++ b/torchmetrics/functional/classification/hamming_distance.py
@@ -55,11 +55,11 @@ def hamming_distance(preds: Tensor, target: Tensor, threshold: float = 0.5) -> Tensor:
    Accepts all input types listed in :ref:`references/modules:input types`.

    Args:
-        preds: Predictions from model
+        preds: Predictions from model (probabilities, logits or labels)
        target: Ground truth
        threshold:
-            Threshold probability value for transforming probability predictions to binary
-            (0 or 1) predictions, in the case of binary or multi-label inputs.
+            Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
+            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.

    Example:
        >>> from torchmetrics.functional import hamming_distance
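The same pattern extends to the fbeta/f1 functionals documented above; a hedged multi-label sketch (shapes illustrative, threshold chosen for logit-space inputs):

    import torch
    from torchmetrics.functional import f1

    logits = torch.randn(8, 3)                   # (N, num_labels) multi-label logits
    target = torch.randint(high=2, size=(8, 3))

    f1(logits, target, num_classes=3, threshold=0.0)
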
diff --git a/torchmetrics/functional/classification/precision_recall.py b/torchmetrics/functional/classification/precision_recall.py
index d9e87924bf4..0203e99671b 100644
--- a/torchmetrics/functional/classification/precision_recall.py
+++ b/torchmetrics/functional/classification/precision_recall.py
@@ -66,7 +66,7 @@ def precision(
    multi-dimensional multi-class case. Accepts all inputs listed in :ref:`references/modules:input types`.

    Args:
-        preds: Predictions from model (probabilities or labels)
+        preds: Predictions from model (probabilities, logits or labels)
        target: Ground truth values
        average:
            Defines the reduction that is applied. Should be one of the following:
@@ -111,15 +111,14 @@ def precision(
            Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
        threshold:
-            Threshold probability value for transforming probability predictions to binary
-            (0,1) predictions, in the case of binary or multi-label inputs.
+            Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
+            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
        top_k:
-            Number of highest probability entries for each sample to convert to 1s - relevant
-            only for inputs with probability predictions. If this parameter is set for multi-label
-            inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs,
-            this parameter defaults to 1.
+            Number of highest probability or logit score predictions considered to find the correct label,
+            relevant only for (multi-dimensional) multi-class inputs. The
+            default value (``None``) will be interpreted as 1 for these inputs.

-            Should be left unset (``None``) for inputs with label predictions.
+            Should be left at default (``None``) for all other types of inputs.
        multiclass:
            Used only in certain special cases, where you want to treat inputs as a different type
            than what they appear to be. See the parameter's
@@ -240,7 +239,7 @@ def recall(
    multi-dimensional multi-class case. Accepts all inputs listed in :ref:`references/modules:input types`.

    Args:
-        preds: Predictions from model (probabilities, or labels)
+        preds: Predictions from model (probabilities, logits or labels)
        target: Ground truth values
        average:
            Defines the reduction that is applied. Should be one of the following:
@@ -285,15 +284,14 @@ def recall(
            Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
        threshold:
-            Threshold probability value for transforming probability predictions to binary
-            (0,1) predictions, in the case of binary or multi-label inputs
+            Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
+            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
        top_k:
-            Number of highest probability entries for each sample to convert to 1s - relevant
-            only for inputs with probability predictions. If this parameter is set for multi-label
-            inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs,
-            this parameter defaults to 1.
+            Number of highest probability or logit score predictions considered to find the correct label,
+            relevant only for (multi-dimensional) multi-class inputs. The
+            default value (``None``) will be interpreted as 1 for these inputs.

-            Should be left unset (``None``) for inputs with label predictions.
+            Should be left at default (``None``) for all other types of inputs.
        multiclass:
            Used only in certain special cases, where you want to treat inputs as a different type
            than what they appear to be. See the parameter's
@@ -398,7 +396,7 @@ def precision_recall(
    multi-dimensional multi-class case. Accepts all inputs listed in :ref:`references/modules:input types`.

    Args:
-        preds: Predictions from model (probabilities, or labels)
+        preds: Predictions from model (probabilities, logits or labels)
        target: Ground truth values
        average:
            Defines the reduction that is applied. Should be one of the following:
@@ -443,15 +441,14 @@ def precision_recall(
            Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
        threshold:
-            Threshold probability value for transforming probability predictions to binary
-            (0,1) predictions, in the case of binary or multi-label inputs
+            Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
+            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
        top_k:
-            Number of highest probability entries for each sample to convert to 1s - relevant
-            only for inputs with probability predictions. If this parameter is set for multi-label
-            inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs,
-            this parameter defaults to 1.
+            Number of highest probability or logit score predictions considered to find the correct label,
+            relevant only for (multi-dimensional) multi-class inputs. The
+            default value (``None``) will be interpreted as 1 for these inputs.

-            Should be left unset (``None``) for inputs with label predictions.
+            Should be left at default (``None``) for all other types of inputs.
        multiclass:
            Used only in certain special cases, where you want to treat inputs as a different type
            than what they appear to be. See the parameter's
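Tying the precision/recall docstring changes above together, a hedged multi-class sketch; precision_recall returns the two scores as a tuple (shapes illustrative):

    import torch
    from torchmetrics.functional import precision_recall

    logits = torch.randn(10, 5)                  # (N, C); rows need not sum to 1
    target = torch.randint(high=5, size=(10,))

    prec, rec = precision_recall(logits, target, average='macro', num_classes=5)
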
diff --git a/torchmetrics/functional/classification/stat_scores.py b/torchmetrics/functional/classification/stat_scores.py
index 1c86c16708b..c2e62f69aa7 100644
--- a/torchmetrics/functional/classification/stat_scores.py
+++ b/torchmetrics/functional/classification/stat_scores.py
@@ -159,19 +159,18 @@ def stat_scores(
    multi-dimensional multi-class case. Accepts all inputs listed in :ref:`references/modules:input types`.

    Args:
-        preds: Predictions from model (probabilities or labels)
+        preds: Predictions from model (probabilities, logits or labels)
        target: Ground truth values
        threshold:
-            Threshold probability value for transforming probability predictions to binary
-            (0 or 1) predictions, in the case of binary or multi-label inputs.
+            Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
+            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.

        top_k:
-            Number of highest probability entries for each sample to convert to 1s - relevant
-            only for inputs with probability predictions. If this parameter is set for multi-label
-            inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs,
-            this parameter defaults to 1.
+            Number of highest probability or logit score predictions considered to find the correct label,
+            relevant only for (multi-dimensional) multi-class inputs. The
+            default value (``None``) will be interpreted as 1 for these inputs.

-            Should be left unset (``None``) for inputs with label predictions.
+            Should be left at default (``None``) for all other types of inputs.

        reduce:
            Defines the reduction that is applied. Should be one of the following:
diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py
index 5cf0c842e13..339d1fd0215 100644
--- a/torchmetrics/utilities/checks.py
+++ b/torchmetrics/utilities/checks.py
@@ -44,12 +44,6 @@ def _basic_input_validation(preds: Tensor, target: Tensor, threshold: float, multiclass
    if not preds.shape[0] == target.shape[0]:
        raise ValueError("The `preds` and `target` should have the same first dimension.")

-    if preds_float and (preds.min() < 0 or preds.max() > 1):
-        raise ValueError("The `preds` should be probabilities, but values were detected outside of [0,1] range.")
-
-    if not 0 < threshold < 1:
-        raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}")
-
    if multiclass is False and target.max() > 1:
        raise ValueError("If you set `multiclass=False`, then `target` should not exceed 1.")
@@ -270,11 +264,6 @@ def _check_classification_inputs(
    # Check that shape/types fall into one of the cases
    case, implied_classes = _check_shape_and_type_consistency(preds, target)

-    # For (multi-dim) multi-class case with prob preds, check that preds sum up to 1
-    if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and preds.is_floating_point():
-        if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all():
-            raise ValueError("Probabilities in `preds` must sum up to 1 across the `C` dimension.")
-
    # Check consistency with the `C` dimension in case of multi-class data
    if preds.shape != target.shape:
        if multiclass is False and implied_classes != 2:
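With the range, threshold and sum-to-1 validations removed above, stat_scores (and every metric built on it) should accept raw multi-class scores directly. A hedged sketch; on versions before this change the call below would have raised (tensors illustrative):

    import torch
    from torchmetrics.functional import stat_scores

    logits = torch.randn(10, 5)                  # would previously fail the sum-to-1 check
    target = torch.randint(high=5, size=(10,))

    # with reduce='micro' the result is a tensor of [tp, fp, tn, fn, support]
    scores = stat_scores(logits, target, reduce='micro', num_classes=5)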