prune metric: accuracy 4/n #6515

Merged · 9 commits · Mar 17, 2021
Changes from all commits
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -75,6 +75,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

[#6547](https://github.com/PyTorchLightning/pytorch-lightning/pull/6547),

[#6515](https://github.com/PyTorchLightning/pytorch-lightning/pull/6515),

)


127 changes: 7 additions & 120 deletions pytorch_lightning/metrics/classification/accuracy.py
@@ -13,92 +13,14 @@
# limitations under the License.
from typing import Any, Callable, Optional

import torch
from torchmetrics import Metric
from torchmetrics import Accuracy as _Accuracy

from pytorch_lightning.metrics.functional.accuracy import _accuracy_compute, _accuracy_update
from pytorch_lightning.utilities.deprecation import deprecated


class Accuracy(Metric):
r"""
Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`__:

.. math::
\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)

Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a
tensor of predictions.

For multi-class and multi-dimensional multi-class data with probability predictions, the
parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the
top-K highest probability items are considered to find the correct label.

For multi-label and multi-dimensional multi-class inputs, this metric computes the "global"
accuracy by default, which counts all labels or sub-samples separately. This can be
changed to subset accuracy (which requires all labels or sub-samples in the sample to
be correctly predicted) by setting ``subset_accuracy=True``.

Args:
threshold:
Threshold probability value for transforming probability predictions to binary
(0,1) predictions, in the case of binary or multi-label inputs.
top_k:
Number of highest probability predictions considered to find the correct label, relevant
only for (multi-dimensional) multi-class inputs with probability predictions. The
default value (``None``) will be interpreted as 1 for these inputs.

Should be left at default (``None``) for all other types of inputs.
subset_accuracy:
Whether to compute subset accuracy for multi-label and multi-dimensional
multi-class inputs (has no effect for other input types).

- For multi-label inputs, if the parameter is set to ``True``, then all labels for
each sample must be correctly predicted for the sample to count as correct. If it
is set to ``False``, then all labels are counted separately - this is equivalent to
flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``).

- For multi-dimensional multi-class inputs, if the parameter is set to ``True``, then all
sub-samples (on the extra axis) must be correct for the sample to be counted as correct.
If it is set to ``False``, then all sub-samples are counted separately - this is equivalent,
in the case of label predictions, to flattening the inputs beforehand (i.e.
``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter
still applies in both cases, if set.

compute_on_step:
Forward only calls ``update()`` and returns ``None`` if this is set to ``False``.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step
process_group:
Specify the process group on which synchronization is called.
default: ``None`` (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather

Raises:
ValueError:
If ``threshold`` is not between ``0`` and ``1``.
ValueError:
If ``top_k`` is not an ``integer`` larger than ``0``.

Example:

>>> from pytorch_lightning.metrics import Accuracy
>>> target = torch.tensor([0, 1, 2, 3])
>>> preds = torch.tensor([0, 2, 1, 3])
>>> accuracy = Accuracy()
>>> accuracy(preds, target)
tensor(0.5000)

>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
>>> accuracy = Accuracy(top_k=2)
>>> accuracy(preds, target)
tensor(0.6667)

"""
class Accuracy(_Accuracy):

@deprecated(target=_Accuracy, ver_deprecate="1.3.0", ver_remove="1.5.0")
def __init__(
self,
threshold: float = 0.5,
@@ -109,44 +31,9 @@ def __init__(
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)

self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

if not 0 < threshold < 1:
raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}")

if top_k is not None and (not isinstance(top_k, int) or top_k <= 0):
raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}")

self.threshold = threshold
self.top_k = top_k
self.subset_accuracy = subset_accuracy

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.

Args:
preds: Predictions from model (probabilities, or labels)
target: Ground truth labels
"""
This implementation refers to :class:`~torchmetrics.Accuracy`.

correct, total = _accuracy_update(
preds, target, threshold=self.threshold, top_k=self.top_k, subset_accuracy=self.subset_accuracy
)

self.correct += correct
self.total += total

def compute(self) -> torch.Tensor:
"""
Computes accuracy based on inputs passed in to ``update`` previously.
.. deprecated::
Use :class:`~torchmetrics.Accuracy`. Will be removed in v1.5.0.
"""
return _accuracy_compute(self.correct, self.total)
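For reference, after this change the class is only a thin deprecation shim that forwards to torchmetrics. A minimal migration sketch, assuming the torchmetrics API available around this PR (roughly v0.2; later torchmetrics releases additionally require a `task` argument):

```python
import torch
from torchmetrics import Accuracy  # the `target` the deprecated class delegates to

target = torch.tensor([0, 1, 2, 3])
preds = torch.tensor([0, 2, 1, 3])

# Same call pattern as the pruned pytorch_lightning.metrics.Accuracy
accuracy = Accuracy()
accuracy(preds, target)  # tensor(0.5000)

# Top-k accuracy on probability predictions, mirroring the removed docstring example
probs = torch.tensor([[0.1, 0.9, 0.0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
Accuracy(top_k=2)(probs, torch.tensor([0, 1, 2]))  # tensor(0.6667)
```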
103 changes: 7 additions & 96 deletions pytorch_lightning/metrics/functional/accuracy.py
@@ -11,112 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
from typing import Optional

import torch
from torchmetrics.classification.checks import _input_format_classification
from torchmetrics.utilities.enums import DataType
from torchmetrics.functional import accuracy as _accuracy


def _accuracy_update(
preds: torch.Tensor, target: torch.Tensor, threshold: float, top_k: Optional[int], subset_accuracy: bool
) -> Tuple[torch.Tensor, torch.Tensor]:

preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k)

if mode == DataType.MULTILABEL and top_k:
raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.")

if mode == DataType.BINARY or (mode == DataType.MULTILABEL and subset_accuracy):
correct = (preds == target).all(dim=1).sum()
total = torch.tensor(target.shape[0], device=target.device)
elif mode == DataType.MULTILABEL and not subset_accuracy:
correct = (preds == target).sum()
total = torch.tensor(target.numel(), device=target.device)
elif mode == DataType.MULTICLASS or (mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy):
correct = (preds * target).sum()
total = target.sum()
elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy:
sample_correct = (preds * target).sum(dim=(1, 2))
correct = (sample_correct == target.shape[2]).sum()
total = torch.tensor(target.shape[0], device=target.device)

return correct, total


def _accuracy_compute(correct: torch.Tensor, total: torch.Tensor) -> torch.Tensor:
return correct.float() / total
from pytorch_lightning.utilities.deprecation import deprecated


@deprecated(target=_accuracy, ver_deprecate="1.3.0", ver_remove="1.5.0")
def accuracy(
preds: torch.Tensor,
target: torch.Tensor,
threshold: float = 0.5,
top_k: Optional[int] = None,
subset_accuracy: bool = False,
) -> torch.Tensor:
r"""Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`_:

.. math::
\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)

Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a
tensor of predictions.

For multi-class and multi-dimensional multi-class data with probability predictions, the
parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the
top-K highest probability items are considered to find the correct label.

For multi-label and multi-dimensional multi-class inputs, this metric computes the "global"
accuracy by default, which counts all labels or sub-samples separately. This can be
changed to subset accuracy (which requires all labels or sub-samples in the sample to
be correctly predicted) by setting ``subset_accuracy=True``.

Args:
preds: Predictions from model (probabilities, or labels)
target: Ground truth labels
threshold:
Threshold probability value for transforming probability predictions to binary
(0,1) predictions, in the case of binary or multi-label inputs.
top_k:
Number of highest probability predictions considered to find the correct label, relevant
only for (multi-dimensional) multi-class inputs with probability predictions. The
default value (``None``) will be interpreted as 1 for these inputs.

Should be left at default (``None``) for all other types of inputs.
subset_accuracy:
Whether to compute subset accuracy for multi-label and multi-dimensional
multi-class inputs (has no effect for other input types).

- For multi-label inputs, if the parameter is set to ``True``, then all labels for
each sample must be correctly predicted for the sample to count as correct. If it
is set to ``False``, then all labels are counted separately - this is equivalent to
flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``).

- For multi-dimensional multi-class inputs, if the parameter is set to ``True``, then all
sub-samples (on the extra axis) must be correct for the sample to be counted as correct.
If it is set to ``False``, then all sub-samples are counted separately - this is equivalent,
in the case of label predictions, to flattening the inputs beforehand (i.e.
``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter
still applies in both cases, if set.

Raises:
ValueError:
If ``top_k`` parameter is set for ``multi-label`` inputs.

Example:

>>> from pytorch_lightning.metrics.functional import accuracy
>>> target = torch.tensor([0, 1, 2, 3])
>>> preds = torch.tensor([0, 2, 1, 3])
>>> accuracy(preds, target)
tensor(0.5000)

>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
>>> accuracy(preds, target, top_k=2)
tensor(0.6667)
"""

correct, total = _accuracy_update(preds, target, threshold, top_k, subset_accuracy)
return _accuracy_compute(correct, total)
.. deprecated::
Use :func:`torchmetrics.functional.accuracy`. Will be removed in v1.5.0.
"""
19 changes: 16 additions & 3 deletions tests/deprecated_api/test_remove_1-5_metrics.py
@@ -17,10 +17,11 @@
import torch

from pytorch_lightning.metrics import Accuracy, MetricCollection
from pytorch_lightning.metrics.functional.accuracy import accuracy
from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot


def test_v1_5_0_metrics_utils():
def test_v1_5_metrics_utils():
x = torch.tensor([1, 2, 3])
with pytest.deprecated_call(match="It will be removed in v1.5.0"):
assert torch.equal(to_onehot(x), torch.Tensor([[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]).to(int))
@@ -37,12 +38,24 @@ def test_v1_5_0_metrics_utils():
assert torch.equal(to_categorical(x), torch.Tensor([1, 0]).to(int))


def test_v1_5_0_metrics_collection():
def test_v1_5_metric_accuracy():
accuracy.warned = False
with pytest.deprecated_call(match='It will be removed in v1.5.0'):
assert accuracy(preds=torch.tensor([0, 1]), target=torch.tensor([0, 1])) == torch.tensor(1.)

Accuracy.__init__.warned = False
with pytest.deprecated_call(match='It will be removed in v1.5.0'):
Accuracy()


def test_v1_5_metrics_collection():
target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])

MetricCollection.__init__.warned = False
with pytest.deprecated_call(
match="`pytorch_lightning.metrics.metric.MetricCollection` was deprecated since v1.3.0 in favor"
" of `torchmetrics.collections.MetricCollection`. It will be removed in v1.5.0."
):
metrics = MetricCollection([Accuracy()])
assert metrics(preds, target) == {'Accuracy': torch.Tensor([0.1250])[0]}
assert metrics(preds, target) == {'Accuracy': torch.tensor(0.1250)}
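The tests reset `accuracy.warned = False` and `Accuracy.__init__.warned = False` because the deprecation decorator warns only once per wrapped callable; resetting the flag lets `pytest.deprecated_call` observe the warning again. A hypothetical sketch of that warn-once pattern (not the actual `pytorch_lightning.utilities.deprecation.deprecated` implementation, which also remaps arguments onto the torchmetrics target):

```python
import warnings
from functools import wraps


def deprecated_sketch(target, ver_deprecate: str, ver_remove: str):
    """Illustrative warn-once decorator; names and behavior are simplified."""

    def decorator(fn):

        @wraps(fn)
        def wrapper(*args, **kwargs):
            if not wrapper.warned:
                warnings.warn(
                    f"`{fn.__qualname__}` was deprecated since v{ver_deprecate} in favor of"
                    f" `{target.__module__}.{target.__qualname__}`."
                    f" It will be removed in v{ver_remove}.", DeprecationWarning
                )
                wrapper.warned = True  # suppress the warning on subsequent calls
            return fn(*args, **kwargs)

        wrapper.warned = False  # tests flip this back to False to re-trigger the warning
        return wrapper

    return decorator
```

Without the reset, a later test touching the same wrapped callable in the same session would see no warning and `pytest.deprecated_call` would fail.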