fix default vals for functional forms
rittik9 committed Jan 8, 2025
1 parent de8bd1e commit 8b1ad0c
Showing 5 changed files with 38 additions and 38 deletions.
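Every changed pair in the diff below removes a default of average="macro" and adds average="micro", with the doctest outputs updated from the macro-averaged to the micro-averaged result for the same inputs. As a quick plain-Python sanity check of that arithmetic (not part of the commit; numbers taken from the hamming example in the first file): micro pools all samples before dividing, while macro averages the per-class rates.

>>> 1 / 4                      # micro hamming distance: 1 wrong prediction out of 4 pooled samples
0.25
>>> sum([0.5, 0.0, 0.0]) / 3   # macro: unweighted mean of the per-class error rates
0.16666666666666666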
10 changes: 5 additions & 5 deletions src/torchmetrics/functional/classification/hamming.py
@@ -163,7 +163,7 @@ def multiclass_hamming_distance(
preds: Tensor,
target: Tensor,
num_classes: int,
- average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+ average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
top_k: int = 1,
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
@@ -231,7 +231,7 @@ def multiclass_hamming_distance(
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_hamming_distance(preds, target, num_classes=3)
- tensor(0.1667)
+ tensor(0.2500)
>>> multiclass_hamming_distance(preds, target, num_classes=3, average=None)
tensor([0.5000, 0.0000, 0.0000])
@@ -243,7 +243,7 @@ def multiclass_hamming_distance(
... [0.71, 0.09, 0.20],
... [0.05, 0.82, 0.13]])
>>> multiclass_hamming_distance(preds, target, num_classes=3)
- tensor(0.1667)
+ tensor(0.2500)
>>> multiclass_hamming_distance(preds, target, num_classes=3, average=None)
tensor([0.5000, 0.0000, 0.0000])
@@ -252,7 +252,7 @@ def multiclass_hamming_distance(
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_hamming_distance(preds, target, num_classes=3, multidim_average='samplewise')
- tensor([0.5000, 0.7222])
+ tensor([0.5000, 0.6667])
>>> multiclass_hamming_distance(preds, target, num_classes=3, multidim_average='samplewise', average=None)
tensor([[0.0000, 1.0000, 0.5000],
[1.0000, 0.6667, 0.5000]])
@@ -273,7 +273,7 @@ def multilabel_hamming_distance(
target: Tensor,
num_labels: int,
threshold: float = 0.5,
- average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+ average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
validate_args: bool = True,
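A hedged sketch of why the hamming doctest above moves from 0.1667 to 0.2500: with target [2, 1, 0, 0] and preds [2, 1, 0, 1] there is one mismatch in four samples, so the micro value is 1/4, while the macro value averages the per-class rates [0.5, 0.0, 0.0]. Passing average explicitly reproduces both numbers regardless of the default (assuming torchmetrics is installed and exports the function from torchmetrics.functional.classification, as the file path suggests; both expected values are taken from the old and new doctests above).

>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_hamming_distance
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_hamming_distance(preds, target, num_classes=3, average="micro")  # 1 error / 4 samples
tensor(0.2500)
>>> multiclass_hamming_distance(preds, target, num_classes=3, average="macro")  # mean([0.5, 0.0, 0.0])
tensor(0.1667)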
16 changes: 8 additions & 8 deletions src/torchmetrics/functional/classification/negative_predictive_value.py
@@ -139,7 +139,7 @@ def multiclass_negative_predictive_value(
preds: Tensor,
target: Tensor,
num_classes: int,
- average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+ average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
top_k: int = 1,
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
@@ -208,7 +208,7 @@ def multiclass_negative_predictive_value(
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_negative_predictive_value(preds, target, num_classes=3)
- tensor(0.8889)
+ tensor(0.8750)
>>> multiclass_negative_predictive_value(preds, target, num_classes=3, average=None)
tensor([0.6667, 1.0000, 1.0000])
@@ -220,7 +220,7 @@ def multiclass_negative_predictive_value(
... [0.71, 0.09, 0.20],
... [0.05, 0.82, 0.13]])
>>> multiclass_negative_predictive_value(preds, target, num_classes=3)
- tensor(0.8889)
+ tensor(0.8750)
>>> multiclass_negative_predictive_value(preds, target, num_classes=3, average=None)
tensor([0.6667, 1.0000, 1.0000])
@@ -229,7 +229,7 @@ def multiclass_negative_predictive_value(
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_negative_predictive_value(preds, target, num_classes=3, multidim_average='samplewise')
- tensor([0.7833, 0.6556])
+ tensor([0.7500, 0.6667])
>>> multiclass_negative_predictive_value(
... preds, target, num_classes=3, multidim_average='samplewise', average=None
... )
@@ -254,7 +254,7 @@ def multilabel_negative_predictive_value(
target: Tensor,
num_labels: int,
threshold: float = 0.5,
- average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+ average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
validate_args: bool = True,
@@ -320,7 +320,7 @@ def multilabel_negative_predictive_value(
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_negative_predictive_value(preds, target, num_labels=3)
- tensor(0.5000)
+ tensor(0.6667)
>>> multilabel_negative_predictive_value(preds, target, num_labels=3, average=None)
tensor([1.0000, 0.5000, 0.0000])
@@ -329,7 +329,7 @@ def multilabel_negative_predictive_value(
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_negative_predictive_value(preds, target, num_labels=3)
- tensor(0.5000)
+ tensor(0.6667)
>>> multilabel_negative_predictive_value(preds, target, num_labels=3, average=None)
tensor([1.0000, 0.5000, 0.0000])
@@ -339,7 +339,7 @@ def multilabel_negative_predictive_value(
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> multilabel_negative_predictive_value(preds, target, num_labels=3, multidim_average='samplewise')
- tensor([0.0000, 0.1667])
+ tensor([0.0000, 0.2500])
>>> multilabel_negative_predictive_value(
... preds, target, num_labels=3, multidim_average='samplewise', average=None
... )
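The negative-predictive-value doctests follow the same pattern. For the single-dimensional example above (target [2, 1, 0, 0], preds [2, 1, 0, 1]) the pooled counts are TN = 7 and FN = 1, so micro NPV is 7/8 = 0.8750, while macro NPV is the mean of the per-class values [0.6667, 1.0, 1.0], about 0.8889. A minimal check with explicit average arguments (expected values taken from the old and new doctests above):

>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_negative_predictive_value
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_negative_predictive_value(preds, target, num_classes=3, average="micro")  # TN=7, FN=1
tensor(0.8750)
>>> multiclass_negative_predictive_value(preds, target, num_classes=3, average="macro")  # mean([0.6667, 1.0, 1.0])
tensor(0.8889)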
26 changes: 13 additions & 13 deletions src/torchmetrics/functional/classification/precision_recall.py
@@ -141,7 +141,7 @@ def multiclass_precision(
preds: Tensor,
target: Tensor,
num_classes: int,
- average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+ average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
top_k: int = 1,
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
@@ -209,7 +209,7 @@ def multiclass_precision(
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_precision(preds, target, num_classes=3)
- tensor(0.8333)
+ tensor(0.7500)
>>> multiclass_precision(preds, target, num_classes=3, average=None)
tensor([1.0000, 0.5000, 1.0000])
@@ -221,7 +221,7 @@ def multiclass_precision(
... [0.71, 0.09, 0.20],
... [0.05, 0.82, 0.13]])
>>> multiclass_precision(preds, target, num_classes=3)
- tensor(0.8333)
+ tensor(0.7500)
>>> multiclass_precision(preds, target, num_classes=3, average=None)
tensor([1.0000, 0.5000, 1.0000])
@@ -230,7 +230,7 @@ def multiclass_precision(
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_precision(preds, target, num_classes=3, multidim_average='samplewise')
- tensor([0.3889, 0.2778])
+ tensor([0.5000, 0.3333])
>>> multiclass_precision(preds, target, num_classes=3, multidim_average='samplewise', average=None)
tensor([[0.6667, 0.0000, 0.5000],
[0.0000, 0.5000, 0.3333]])
@@ -261,7 +261,7 @@ def multilabel_precision(
target: Tensor,
num_labels: int,
threshold: float = 0.5,
- average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+ average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
validate_args: bool = True,
@@ -326,7 +326,7 @@ def multilabel_precision(
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_precision(preds, target, num_labels=3)
- tensor(0.5000)
+ tensor(0.6667)
>>> multilabel_precision(preds, target, num_labels=3, average=None)
tensor([1.0000, 0.0000, 0.5000])
@@ -335,7 +335,7 @@ def multilabel_precision(
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_precision(preds, target, num_labels=3)
- tensor(0.5000)
+ tensor(0.6667)
>>> multilabel_precision(preds, target, num_labels=3, average=None)
tensor([1.0000, 0.0000, 0.5000])
@@ -345,7 +345,7 @@ def multilabel_precision(
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> multilabel_precision(preds, target, num_labels=3, multidim_average='samplewise')
- tensor([0.3333, 0.0000])
+ tensor([0.4000, 0.0000])
>>> multilabel_precision(preds, target, num_labels=3, multidim_average='samplewise', average=None)
tensor([[0.5000, 0.5000, 0.0000],
[0.0000, 0.0000, 0.0000]])
@@ -451,7 +451,7 @@ def multiclass_recall(
preds: Tensor,
target: Tensor,
num_classes: int,
- average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+ average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
top_k: int = 1,
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
@@ -519,7 +519,7 @@ def multiclass_recall(
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_recall(preds, target, num_classes=3)
- tensor(0.8333)
+ tensor(0.7500)
>>> multiclass_recall(preds, target, num_classes=3, average=None)
tensor([0.5000, 1.0000, 1.0000])
@@ -531,7 +531,7 @@ def multiclass_recall(
... [0.71, 0.09, 0.20],
... [0.05, 0.82, 0.13]])
>>> multiclass_recall(preds, target, num_classes=3)
- tensor(0.8333)
+ tensor(0.7500)
>>> multiclass_recall(preds, target, num_classes=3, average=None)
tensor([0.5000, 1.0000, 1.0000])
@@ -540,7 +540,7 @@ def multiclass_recall(
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_recall(preds, target, num_classes=3, multidim_average='samplewise')
- tensor([0.5000, 0.2778])
+ tensor([0.5000, 0.3333])
>>> multiclass_recall(preds, target, num_classes=3, multidim_average='samplewise', average=None)
tensor([[1.0000, 0.0000, 0.5000],
[0.0000, 0.3333, 0.5000]])
@@ -571,7 +571,7 @@ def multilabel_recall(
target: Tensor,
num_labels: int,
threshold: float = 0.5,
- average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+ average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
validate_args: bool = True,
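For precision and recall the shift is analogous: micro precision pools TP and FP over all classes (3 correct predictions and 1 false positive give 0.7500), while macro averages the per-class precisions [1.0, 0.5, 1.0] to 0.8333; the multilabel example pools TP = 2 and FP = 1 to get 0.6667. A short sketch with explicit average arguments, using the same inputs and expected values as the doctests above:

>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_precision, multilabel_precision
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_precision(preds, target, num_classes=3, average="micro")  # TP=3, FP=1
tensor(0.7500)
>>> multiclass_precision(preds, target, num_classes=3, average="macro")  # mean([1.0, 0.5, 1.0])
tensor(0.8333)
>>> multilabel_precision(tensor([[0, 0, 1], [1, 0, 1]]), tensor([[0, 1, 0], [1, 0, 1]]), num_labels=3, average="micro")
tensor(0.6667)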
10 changes: 5 additions & 5 deletions src/torchmetrics/functional/classification/specificity.py
@@ -132,7 +132,7 @@ def multiclass_specificity(
preds: Tensor,
target: Tensor,
num_classes: int,
- average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+ average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
top_k: int = 1,
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
@@ -198,7 +198,7 @@ def multiclass_specificity(
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_specificity(preds, target, num_classes=3)
- tensor(0.8889)
+ tensor(0.8750)
>>> multiclass_specificity(preds, target, num_classes=3, average=None)
tensor([1.0000, 0.6667, 1.0000])
@@ -210,7 +210,7 @@ def multiclass_specificity(
... [0.71, 0.09, 0.20],
... [0.05, 0.82, 0.13]])
>>> multiclass_specificity(preds, target, num_classes=3)
- tensor(0.8889)
+ tensor(0.8750)
>>> multiclass_specificity(preds, target, num_classes=3, average=None)
tensor([1.0000, 0.6667, 1.0000])
@@ -219,7 +219,7 @@ def multiclass_specificity(
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_specificity(preds, target, num_classes=3, multidim_average='samplewise')
- tensor([0.7500, 0.6556])
+ tensor([0.7500, 0.6667])
>>> multiclass_specificity(preds, target, num_classes=3, multidim_average='samplewise', average=None)
tensor([[0.7500, 0.7500, 0.7500],
[0.8000, 0.6667, 0.5000]])
@@ -240,7 +240,7 @@ def multilabel_specificity(
target: Tensor,
num_labels: int,
threshold: float = 0.5,
- average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+ average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
validate_args: bool = True,
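Specificity follows the same arithmetic: pooling gives TN = 7 and FP = 1, so micro specificity is 7/8 = 0.8750, while macro averages the per-class values [1.0, 0.6667, 1.0] to about 0.8889 (both numbers appear in the doctest pair above). A minimal check with explicit average arguments:

>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_specificity
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_specificity(preds, target, num_classes=3, average="micro")  # TN=7, FP=1
tensor(0.8750)
>>> multiclass_specificity(preds, target, num_classes=3, average="macro")  # mean([1.0, 0.6667, 1.0])
tensor(0.8889)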
14 changes: 7 additions & 7 deletions src/torchmetrics/functional/classification/stat_scores.py
@@ -220,7 +220,7 @@ def binary_stat_scores(
def _multiclass_stat_scores_arg_validation(
num_classes: int,
top_k: int = 1,
- average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+ average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
zero_division: float = 0,
@@ -369,7 +369,7 @@ def _multiclass_stat_scores_update(
target: Tensor,
num_classes: int,
top_k: int = 1,
- average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+ average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
@@ -450,7 +450,7 @@ def _multiclass_stat_scores_compute(
fp: Tensor,
tn: Tensor,
fn: Tensor,
- average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+ average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
multidim_average: Literal["global", "samplewise"] = "global",
) -> Tensor:
"""Stack statistics and compute support also.
@@ -478,7 +478,7 @@ def multiclass_stat_scores(
preds: Tensor,
target: Tensor,
num_classes: int,
- average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+ average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
top_k: int = 1,
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
@@ -591,7 +591,7 @@ def multiclass_stat_scores(
def _multilabel_stat_scores_arg_validation(
num_labels: int,
threshold: float = 0.5,
- average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+ average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
zero_division: float = 0,
@@ -715,7 +715,7 @@ def _multilabel_stat_scores_compute(
fp: Tensor,
tn: Tensor,
fn: Tensor,
- average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+ average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
multidim_average: Literal["global", "samplewise"] = "global",
) -> Tensor:
"""Stack statistics and compute support also.
@@ -742,7 +742,7 @@ def multilabel_stat_scores(
target: Tensor,
num_labels: int,
threshold: float = 0.5,
- average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+ average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
validate_args: bool = True,
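The stat_scores changes only touch the default average in the helpers and public functions; no doctest outputs change in this hunk. For reference, a hedged sketch of what the two reductions return for the same small example: with average="micro" the [tp, fp, tn, fn, support] counts are pooled over classes, and with average=None they are reported per class. The expected outputs below follow the standard stat-scores definitions and are not taken from this diff.

>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_stat_scores
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_stat_scores(preds, target, num_classes=3, average="micro")  # pooled [tp, fp, tn, fn, support]
tensor([3, 1, 7, 1, 4])
>>> multiclass_stat_scores(preds, target, num_classes=3, average=None)  # one row per class
tensor([[1, 0, 2, 1, 2],
        [1, 1, 2, 0, 1],
        [1, 0, 3, 0, 1]])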
