
Commit

make is_differentiable as attribute (#551)
* setter
* typing
* Apply suggestions from code review

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas chaton <[email protected]>
3 people authored Sep 27, 2021
1 parent fbdc3eb commit ac52dd7
Showing 40 changed files with 55 additions and 174 deletions.
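
Every file below follows the same pattern: the per-class `is_differentiable` property is deleted and the same value reappears as a constant class attribute (the value varies per metric). A minimal sketch of the before/after shape of the change, using a toy metric that is illustrative rather than taken from the diff:

    import torch
    from torchmetrics import Metric


    class MySum(Metric):
        """Toy metric used only to illustrate the pattern in this commit."""

        # After #551: differentiability is a constant class attribute.
        is_differentiable = True

        # Before #551 the same flag lived in a property on every class:
        #
        #     @property
        #     def is_differentiable(self) -> bool:
        #         return True

        def __init__(self):
            super().__init__()
            self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
            self.total = self.total + (preds - target).sum()

        def compute(self) -> torch.Tensor:
            return self.total

Because the flag now lives in the class namespace, it can be read as `MySum.is_differentiable` without constructing the metric, and (per the testers.py change below) instances reject reassignment with a `RuntimeError`.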
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -33,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - `half`, `double`, `float` will no longer change the dtype of the metric states. Use `metric.set_dtype` instead ([#493](https://github.com/PyTorchLightning/metrics/pull/493))
 
 
+- Changed `is_differentiable` from property to a constant attribute ([#551](https://github.com/PyTorchLightning/metrics/pull/551))
+
+
 ### Deprecated
 
 
14 changes: 4 additions & 10 deletions docs/source/pages/overview.rst
@@ -315,23 +315,17 @@ Metrics and differentiability
 
 Metrics support backpropagation, if all computations involved in the metric calculation
 are differentiable. All modular metrics have a property that determines if a metric is
-differentible or not.
-
-.. code-block:: python
-
-    @property
-    def is_differentiable(self) -> bool:
-        return True/False
+differentiable or not.
 
 However, note that the cached state is detached from the computational
-graph and cannot be backpropagated. Not doing this would mean storing the computational
+graph and cannot be back-propagated. Not doing this would mean storing the computational
 graph for each update call, which can lead to out-of-memory errors.
 In practise this means that:
 
 .. code-block:: python
 
     metric = MyMetric()
-    val = metric(pred, target) # this value can be backpropagated
-    val = metric.compute() # this value cannot be backpropagated
+    val = metric(pred, target) # this value can be back-propagated
+    val = metric.compute() # this value cannot be back-propagated
 
 A functional metric is differentiable if its corresponding modular metric is differentiable.
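
The behaviour documented above can be exercised directly: gradients flow through the value returned by calling the metric, while the value from `compute()` is built from the detached cached state. A short sketch, assuming the `SI_SDR` metric touched by this same commit (shapes and data are arbitrary):

    import torch
    from torchmetrics.audio.si_sdr import SI_SDR

    metric = SI_SDR()
    preds = torch.randn(2, 16000, requires_grad=True)
    target = torch.randn(2, 16000)

    val = metric(preds, target)  # step value, computed on the live graph
    val.backward()               # gradients reach `preds`
    assert preds.grad is not None

    cached = metric.compute()        # built from the detached cached state
    assert not cached.requires_grad  # cannot be back-propagated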
4 changes: 3 additions & 1 deletion tests/helpers/testers.py
@@ -145,10 +145,12 @@ def _class_test(
     if not metric_args:
         metric_args = {}
 
-    # Instanciate lightning metric
+    # Instantiate lightning metric
     metric = metric_class(
         compute_on_step=check_dist_sync_on_step or check_batch, dist_sync_on_step=dist_sync_on_step, **metric_args
     )
+    with pytest.raises(RuntimeError):
+        metric.is_differentiable = not metric.is_differentiable
 
     # check that the metric is scriptable
     if check_scriptable:
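
The new assertion expects instance assignment to `is_differentiable` to raise a `RuntimeError`. The base-class change enforcing this is not shown in this section; a minimal sketch of one way such a guard can work, assuming a `__setattr__` override (not necessarily the exact torchmetrics implementation):

    class Metric:  # illustrative stand-in for torchmetrics.Metric
        is_differentiable: bool = False  # subclasses override the constant

        def __setattr__(self, name: str, value: object) -> None:
            if name == "is_differentiable":
                raise RuntimeError("`is_differentiable` is a constant and cannot be set")
            super().__setattr__(name, value)


    class MyMetric(Metric):
        is_differentiable = True  # fine: set in the class namespace, not on an instance


    metric = MyMetric()
    assert metric.is_differentiable is True
    try:
        metric.is_differentiable = False  # instance assignment is rejected
    except RuntimeError as err:
        print(err)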
5 changes: 1 addition & 4 deletions torchmetrics/audio/pit.py
@@ -69,6 +69,7 @@ class PIT(Metric):
         Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154.
     """
 
+    is_differentiable = True
     sum_pit_metric: Tensor
     total: Tensor
 
@@ -110,7 +111,3 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
     def compute(self) -> Tensor:
         """Computes average PIT metric."""
         return self.sum_pit_metric / self.total
-
-    @property
-    def is_differentiable(self) -> bool:
-        return True
5 changes: 1 addition & 4 deletions torchmetrics/audio/si_sdr.py
@@ -64,6 +64,7 @@ class SI_SDR(Metric):
         and Signal Processing (ICASSP) 2019.
     """
 
+    is_differentiable = True
     sum_si_sdr: Tensor
     total: Tensor
 
@@ -101,7 +102,3 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
     def compute(self) -> Tensor:
         """Computes average SI-SDR."""
         return self.sum_si_sdr / self.total
-
-    @property
-    def is_differentiable(self) -> bool:
-        return True
5 changes: 1 addition & 4 deletions torchmetrics/audio/si_snr.py
@@ -62,6 +62,7 @@ class SI_SNR(Metric):
         696-700, doi: 10.1109/ICASSP.2018.8462116.
     """
 
+    is_differentiable = True
     sum_si_snr: Tensor
     total: Tensor
 
@@ -97,7 +98,3 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
     def compute(self) -> Tensor:
         """Computes average SI-SNR."""
         return self.sum_si_snr / self.total
-
-    @property
-    def is_differentiable(self) -> bool:
-        return True
5 changes: 1 addition & 4 deletions torchmetrics/audio/snr.py
@@ -70,6 +70,7 @@ class SNR(Metric):
         and Signal Processing (ICASSP) 2019.
     """
+    is_differentiable = True
     sum_snr: Tensor
     total: Tensor
 
@@ -107,7 +108,3 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
     def compute(self) -> Tensor:
         """Computes average SNR."""
         return self.sum_snr / self.total
-
-    @property
-    def is_differentiable(self) -> bool:
-        return True
5 changes: 1 addition & 4 deletions torchmetrics/classification/accuracy.py
@@ -164,6 +164,7 @@ class Accuracy(StatScores):
         tensor(0.6667)
     """
+    is_differentiable = False
     correct: Tensor
     total: Tensor
 
@@ -273,7 +274,3 @@ def compute(self) -> Tensor:
             return _subset_accuracy_compute(self.correct, self.total)
         tp, fp, tn, fn = self._get_final_stats()
         return _accuracy_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce, self.mode)
-
-    @property
-    def is_differentiable(self) -> bool:
-        return False
7 changes: 1 addition & 6 deletions torchmetrics/classification/auc.py
@@ -43,6 +43,7 @@ class AUC(Metric):
            Callback that performs the ``allgather`` operation on the metric state. When ``None``, DDP
            will be used to perform the ``allgather``.
    """
+    is_differentiable = False
    x: List[Tensor]
    y: List[Tensor]
 
@@ -88,9 +89,3 @@ def compute(self) -> Tensor:
        x = dim_zero_cat(self.x)
        y = dim_zero_cat(self.y)
        return _auc_compute(x, y, reorder=self.reorder)
-
-    @property
-    def is_differentiable(self) -> bool:
-        """AUC metrics is considered as non differentiable so it should have `false` value for `is_differentiable`
-        property."""
-        return False
7 changes: 1 addition & 6 deletions torchmetrics/classification/auroc.py
@@ -99,6 +99,7 @@ class AUROC(Metric):
         tensor(0.7778)
     """
+    is_differentiable = False
     preds: List[Tensor]
     target: List[Tensor]
 
@@ -183,9 +184,3 @@ def compute(self) -> Tensor:
             self.average,
             self.max_fpr,
         )
-
-    @property
-    def is_differentiable(self) -> bool:
-        """AUROC metrics is considered as non differentiable so it should have `false` value for
-        `is_differentiable` property."""
-        return False
5 changes: 1 addition & 4 deletions torchmetrics/classification/average_precision.py
@@ -84,6 +84,7 @@ class AveragePrecision(Metric):
         [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
     """
 
+    is_differentiable = False
     preds: List[Tensor]
     target: List[Tensor]
 
@@ -144,7 +145,3 @@ def compute(self) -> Union[Tensor, List[Tensor]]:
         if not self.num_classes:
             raise ValueError(f"`num_classes` bas to be positive number, but got {self.num_classes}")
         return _average_precision_compute(preds, target, self.num_classes, self.pos_label, self.average)
-
-    @property
-    def is_differentiable(self) -> bool:
-        return False
7 changes: 1 addition & 6 deletions torchmetrics/classification/cohen_kappa.py
@@ -77,6 +77,7 @@ class labels.
         tensor(0.5000)
     """
+    is_differentiable = False
     confmat: Tensor
 
     def __init__(
@@ -116,9 +117,3 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
     def compute(self) -> Tensor:
         """Computes cohen kappa score."""
         return _cohen_kappa_compute(self.confmat, self.weights)
-
-    @property
-    def is_differentiable(self) -> bool:
-        """cohen kappa is not differentiable since the implementation is based on calculating the confusion matrix
-        which in general is not differentiable."""
-        return False
6 changes: 1 addition & 5 deletions torchmetrics/classification/confusion_matrix.py
@@ -91,7 +91,7 @@ class ConfusionMatrix(Metric):
         [[0., 1.], [0., 1.]]])
     """
-
+    is_differentiable = False
     confmat: Tensor
 
     def __init__(
@@ -139,7 +139,3 @@ def compute(self) -> Tensor:
         this will be a `[n_classes, 2, 2]` tensor
         """
         return _confusion_matrix_compute(self.confmat, self.normalize)
-
-    @property
-    def is_differentiable(self) -> bool:
-        return False
6 changes: 2 additions & 4 deletions torchmetrics/classification/f_beta.py
@@ -269,6 +269,8 @@ class F1(FBeta):
         tensor(0.3333)
     """
 
+    is_differentiable = False
+
     def __init__(
         self,
         num_classes: Optional[int] = None,
@@ -297,7 +299,3 @@ def __init__(
             process_group=process_group,
             dist_sync_fn=dist_sync_fn,
         )
-
-    @property
-    def is_differentiable(self) -> bool:
-        return False
5 changes: 1 addition & 4 deletions torchmetrics/classification/hamming_distance.py
@@ -67,6 +67,7 @@ class HammingDistance(Metric):
         tensor(0.2500)
     """
+    is_differentiable = False
     correct: Tensor
     total: Tensor
 
@@ -107,7 +108,3 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
     def compute(self) -> Tensor:
         """Computes hamming distance based on inputs passed in to ``update`` previously."""
         return _hamming_distance_compute(self.correct, self.total)
-
-    @property
-    def is_differentiable(self) -> bool:
-        return False
5 changes: 1 addition & 4 deletions torchmetrics/classification/hinge.py
@@ -84,6 +84,7 @@ class Hinge(Metric):
         tensor([2.2333, 1.5000, 1.2333])
     """
+    is_differentiable = True
     measure: Tensor
     total: Tensor
 
@@ -124,7 +125,3 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
 
     def compute(self) -> Tensor:
         return _hinge_compute(self.measure, self.total)
-
-    @property
-    def is_differentiable(self) -> bool:
-        return True
5 changes: 1 addition & 4 deletions torchmetrics/classification/iou.py
@@ -77,6 +77,7 @@ class IoU(ConfusionMatrix):
         tensor(0.9660)
     """
+    is_differentiable = False
 
     def __init__(
         self,
@@ -104,7 +105,3 @@ def __init__(
     def compute(self) -> Tensor:
         """Computes intersection over union (IoU)"""
         return _iou_from_confmat(self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction)
-
-    @property
-    def is_differentiable(self) -> bool:
-        return False
5 changes: 1 addition & 4 deletions torchmetrics/classification/kl_divergence.py
@@ -61,6 +61,7 @@ class KLDivergence(Metric):
         tensor(0.0853)
     """
+    is_differentiable = True
     # TODO: canot be used because if scripting
     # measures: Union[List[Tensor], Tensor]
     total: Tensor
@@ -106,7 +107,3 @@ def update(self, p: Tensor, q: Tensor) -> None: # type: ignore
     def compute(self) -> Tensor:
         measures = dim_zero_cat(self.measures) if self.reduction is None or self.reduction == "none" else self.measures
         return _kld_compute(measures, self.total, self.reduction)
-
-    @property
-    def is_differentiable(self) -> bool:
-        return True
5 changes: 1 addition & 4 deletions torchmetrics/classification/matthews_corrcoef.py
@@ -73,6 +73,7 @@ class MatthewsCorrcoef(Metric):
         tensor(0.5774)
     """
+    is_differentiable = False
     confmat: Tensor
 
     def __init__(
@@ -108,7 +109,3 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
     def compute(self) -> Tensor:
         """Computes matthews correlation coefficient."""
         return _matthews_corrcoef_compute(self.confmat)
-
-    @property
-    def is_differentiable(self) -> bool:
-        return False
10 changes: 2 additions & 8 deletions torchmetrics/classification/precision_recall.py
@@ -120,6 +120,7 @@ class Precision(StatScores):
         tensor(0.2500)
     """
+    is_differentiable = False
 
     def __init__(
         self,
@@ -168,10 +169,6 @@ def compute(self) -> Tensor:
         tp, fp, _, fn = self._get_final_stats()
         return _precision_compute(tp, fp, fn, self.average, self.mdmc_reduce)
 
-    @property
-    def is_differentiable(self) -> bool:
-        return False
-
 
 class Recall(StatScores):
     r"""
@@ -273,6 +270,7 @@ class Recall(StatScores):
         tensor(0.2500)
     """
+    is_differentiable = False
 
     def __init__(
         self,
@@ -320,7 +318,3 @@ def compute(self) -> Tensor:
         """
         tp, fp, _, fn = self._get_final_stats()
         return _recall_compute(tp, fp, fn, self.average, self.mdmc_reduce)
-
-    @property
-    def is_differentiable(self) -> bool:
-        return False
5 changes: 1 addition & 4 deletions torchmetrics/classification/precision_recall_curve.py
@@ -81,6 +81,7 @@ class PrecisionRecallCurve(Metric):
         [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
     """
 
+    is_differentiable = False
     preds: List[Tensor]
     target: List[Tensor]
 
@@ -146,7 +147,3 @@ def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], Li
         if not self.num_classes:
             raise ValueError(f"`num_classes` bas to be positive number, but got {self.num_classes}")
         return _precision_recall_curve_compute(preds, target, self.num_classes, self.pos_label)
-
-    @property
-    def is_differentiable(self) -> bool:
-        return False
5 changes: 1 addition & 4 deletions torchmetrics/classification/roc.py
@@ -103,6 +103,7 @@ class ROC(Metric):
         tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])]
     """
 
+    is_differentiable = False
     preds: List[Tensor]
     target: List[Tensor]
 
@@ -166,7 +167,3 @@ def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], Li
         if not self.num_classes:
             raise ValueError(f"`num_classes` bas to be positive number, but got {self.num_classes}")
         return _roc_compute(preds, target, self.num_classes, self.pos_label)
-
-    @property
-    def is_differentiable(self) -> bool:
-        return False
5 changes: 1 addition & 4 deletions torchmetrics/classification/specificity.py
@@ -121,6 +121,7 @@ class Specificity(StatScores):
         tensor(0.6250)
     """
+    is_differentiable = False
 
    def __init__(
        self,
@@ -168,7 +169,3 @@ def compute(self) -> Tensor:
        """
        tp, fp, tn, fn = self._get_final_stats()
        return _specificity_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce)
-
-    @property
-    def is_differentiable(self) -> bool:
-        return False