Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prune metrics: other classification 7/n #6584

Merged
merged 11 commits into from
Mar 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

[#6573](https://github.com/PyTorchLightning/pytorch-lightning/pull/6573),

[#6584](https://github.com/PyTorchLightning/pytorch-lightning/pull/6584),

)


Expand Down
90 changes: 7 additions & 83 deletions pytorch_lightning/metrics/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,64 +13,14 @@
# limitations under the License.
from typing import Any, Optional

import torch
from torchmetrics import Metric
from torchmetrics import ConfusionMatrix as _ConfusionMatrix

from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update
from pytorch_lightning.utilities.deprecation import deprecated


class ConfusionMatrix(Metric):
"""
Computes the `confusion matrix
<https://scikit-learn.org/stable/modules/model_evaluation.html#confusion-matrix>`_. Works with binary,
multiclass, and multilabel data. Accepts probabilities from a model output or
integer class values in prediction. Works with multi-dimensional preds and
target.

Note:
This metric produces a multi-dimensional output, so it can not be directly logged.

Forward accepts

- ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
- ``target`` (long tensor): ``(N, ...)``

If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument
to convert into integer labels. This is the case for binary and multi-label probabilities.

If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.

Args:
num_classes: Number of classes in the dataset.
normalize: Normalization mode for confusion matrix. Choose from

- ``None`` or ``'none'``: no normalization (default)
- ``'true'``: normalization over the targets (most commonly used)
- ``'pred'``: normalization over the predictions
- ``'all'``: normalization over the whole matrix

threshold:
Threshold value for binary or multi-label probabilities. default: 0.5
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)

Example:

>>> from pytorch_lightning.metrics import ConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(num_classes=2)
>>> confmat(preds, target)
tensor([[2., 0.],
[1., 1.]])

"""
class ConfusionMatrix(_ConfusionMatrix):

@deprecated(target=_ConfusionMatrix, ver_deprecate="1.3.0", ver_remove="1.5.0")
def __init__(
self,
num_classes: int,
Expand All @@ -80,35 +30,9 @@ def __init__(
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):

super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
)
self.num_classes = num_classes
self.normalize = normalize
self.threshold = threshold

allowed_normalize = ('true', 'pred', 'all', 'none', None)
assert self.normalize in allowed_normalize, \
f"Argument average needs to one of the following: {allowed_normalize}"

self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum")

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.

Args:
preds: Predictions from model
target: Ground truth values
"""
confmat = _confusion_matrix_update(preds, target, self.num_classes, self.threshold)
self.confmat += confmat
This implementation refers to :class:`~torchmetrics.ConfusionMatrix`.

def compute(self) -> torch.Tensor:
"""
Computes confusion matrix
.. deprecated::
Use :class:`~torchmetrics.ConfusionMatrix`. Will be removed in v1.5.0.
"""
return _confusion_matrix_compute(self.confmat, self.normalize)
180 changes: 15 additions & 165 deletions pytorch_lightning/metrics/classification/f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,72 +13,15 @@
# limitations under the License.
from typing import Any, Optional

import torch
from torchmetrics import Metric
from torchmetrics import F1 as _F1
from torchmetrics import FBeta as _FBeta

from pytorch_lightning.metrics.functional.f_beta import _fbeta_compute, _fbeta_update
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.deprecation import deprecated


class FBeta(Metric):
r"""
Computes `F-score <https://en.wikipedia.org/wiki/F-score>`_, specifically:

.. math::
F_\beta = (1 + \beta^2) * \frac{\text{precision} * \text{recall}}
{(\beta^2 * \text{precision}) + \text{recall}}

Where :math:`\beta` is some positive real factor. Works with binary, multiclass, and multilabel data.
Accepts probabilities from a model output or integer class values in prediction.
Works with multi-dimensional preds and target.

Forward accepts

- ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
- ``target`` (long tensor): ``(N, ...)``

If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument
to convert into integer labels. This is the case for binary and multi-label probabilities.

If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.

Args:
num_classes: Number of classes in the dataset.
beta: Beta coefficient in the F measure.
threshold:
Threshold value for binary or multi-label probabilities. default: 0.5

average:
- ``'micro'`` computes metric globally
- ``'macro'`` computes metric for each class and uniformly averages them
- ``'weighted'`` computes metric for each class and does a weighted-average,
where each class is weighted by their support (accounts for class imbalance)
- ``'none'`` or ``None`` computes and returns the metric per class

multilabel: If predictions are from multilabel classification.
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)

Raises:
ValueError:
If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"``, ``None``.

Example:

>>> from pytorch_lightning.metrics import FBeta
>>> target = torch.tensor([0, 1, 2, 0, 1, 2])
>>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
>>> f_beta = FBeta(num_classes=3, beta=0.5)
>>> f_beta(preds, target)
tensor(0.3333)

"""
class FBeta(_FBeta):

@deprecated(target=_FBeta, ver_deprecate="1.3.0", ver_remove="1.5.0")
def __init__(
self,
num_classes: int,
Expand All @@ -90,103 +33,17 @@ def __init__(
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
)

self.num_classes = num_classes
self.beta = beta
self.threshold = threshold
self.average = average
self.multilabel = multilabel

allowed_average = ("micro", "macro", "weighted", "none", None)
if self.average not in allowed_average:
raise ValueError(
'Argument `average` expected to be one of the following:'
f' {allowed_average} but got {self.average}'
)

self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
self.add_state("actual_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.

Args:
preds: Predictions from model
target: Ground truth values
"""
true_positives, predicted_positives, actual_positives = _fbeta_update(
preds, target, self.num_classes, self.threshold, self.multilabel
)

self.true_positives += true_positives
self.predicted_positives += predicted_positives
self.actual_positives += actual_positives
This implementation refers to :class:`~torchmetrics.FBeta`.

def compute(self) -> torch.Tensor:
.. deprecated::
Use :class:`~torchmetrics.FBeta`. Will be removed in v1.5.0.
"""
Computes fbeta over state.
"""
return _fbeta_compute(
self.true_positives, self.predicted_positives, self.actual_positives, self.beta, self.average
)


class F1(FBeta):
"""
Computes F1 metric. F1 metrics correspond to a harmonic mean of the
precision and recall scores.

Works with binary, multiclass, and multilabel data.
Accepts logits from a model output or integer class values in prediction.
Works with multi-dimensional preds and target.

Forward accepts

- ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
- ``target`` (long tensor): ``(N, ...)``

If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument.
This is the case for binary and multi-label logits.

If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.

Args:
num_classes: Number of classes in the dataset.
threshold:
Threshold value for binary or multi-label logits. default: 0.5

average:
- ``'micro'`` computes metric globally
- ``'macro'`` computes metric for each class and uniformly averages them
- ``'weighted'`` computes metric for each class and does a weighted-average,
where each class is weighted by their support (accounts for class imbalance)
- ``'none'`` or ``None`` computes and returns the metric per class

multilabel: If predictions are from multilabel classification.
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)

Example:
>>> from pytorch_lightning.metrics import F1
>>> target = torch.tensor([0, 1, 2, 0, 1, 2])
>>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
>>> f1 = F1(num_classes=3)
>>> f1(preds, target)
tensor(0.3333)
"""
class F1(_F1):

@deprecated(target=_F1, ver_deprecate="1.3.0", ver_remove="1.5.0")
def __init__(
self,
num_classes: int,
Expand All @@ -197,16 +54,9 @@ def __init__(
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
if multilabel is not False:
rank_zero_warn(f'The `multilabel={multilabel}` parameter is unused and will not have any effect.')
"""
This implementation refers to :class:`~torchmetrics.F1`.

super().__init__(
num_classes=num_classes,
beta=1.0,
threshold=threshold,
average=average,
multilabel=multilabel,
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
)
.. deprecated::
Use :class:`~torchmetrics.F1`. Will be removed in v1.5.0.
"""
Loading