diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index 633e7fd8630..c29f9fb65dd 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -60,6 +60,11 @@ jobs: # prevent hanging Conda creations timeout-minutes: 10 + - name: Temp COCO env. fix + if: matrix.pytorch-version == '1.4' + run: | + pip install -q "numpy==1.20.0" # try to fix cocotools for PT 1.4 + - name: Update Environment run: | sudo apt install libsndfile1 diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f3768c7313..93a84f99888 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `pairwise_euclidean_distance` - `pairwise_linear_similarity` - `pairwise_manhatten_distance` +- Added `MAP` (mean average precision) metric to new detection package ([#467](https://github.com/PyTorchLightning/metrics/pull/467)) ### Changed diff --git a/docs/source/links.rst b/docs/source/links.rst index 497157a7cdc..4b72fb82a78 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -38,6 +38,7 @@ .. _MAPE implementation returns: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_percentage_error.html .. _mean squared logarithmic error: https://scikit-learn.org/stable/modules/model_evaluation.html#mean-squared-log-error .. _LPIPS: https://arxiv.org/abs/1801.03924 +.. _Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR): https://jonathan-hui.medium.com/map-mean-average-precision-for-object-detection-45c121a31173 .. _Tweedie Deviance Score: https://en.wikipedia.org/wiki/Tweedie_distribution#The_Tweedie_deviance .. _Permutation Invariant Training of Deep Models: https://ieeexplore.ieee.org/document/7952154 .. _Computes the Top-label Calibration Error: https://arxiv.org/pdf/1909.10155.pdf diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index b39a047ea5a..92f95a2c5ce 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -396,6 +396,18 @@ SSIM .. autoclass:: torchmetrics.SSIM :noindex: +***************** +Detection Metrics +***************** + +Object detection metrics can be used to evaluate the predicted detections with given groundtruth detections on images. + +MAP +~~~ + +.. 
autoclass:: torchmetrics.MAP + :noindex: + ****************** Regression Metrics ****************** diff --git a/requirements/detection.txt b/requirements/detection.txt new file mode 100644 index 00000000000..661d2e83a15 --- /dev/null +++ b/requirements/detection.txt @@ -0,0 +1,2 @@ +pycocotools>=2.0.2 +torchvision diff --git a/requirements/test.txt b/requirements/test.txt index 09d0bbc441d..b90d5b6e853 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -19,6 +19,7 @@ scikit-image>0.17.1 # add extra requirements -r image.txt -r text.txt +-r detection.txt -r audio.txt # audio diff --git a/setup.py b/setup.py index c9fdcf22f50..c6fefb67caf 100755 --- a/setup.py +++ b/setup.py @@ -26,8 +26,11 @@ def _load_py_module(fname, pkg="torchmetrics"): def _prepare_extras(): extras = { - "image": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="image.txt"), - "text": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="text.txt"), + "image": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="image.txt"), # skipcq: PYL-W0212 + "text": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="text.txt"), # skipcq: PYL-W0212 + "detection": setup_tools._load_requirements( # skipcq: PYL-W0212 + path_dir=_PATH_REQUIRE, file_name="detection.txt" + ), "audio": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="audio.txt"), } # create an 'all' keyword that install all possible denpendencies diff --git a/tests/detection/__init__.py b/tests/detection/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/detection/test_map.py b/tests/detection/test_map.py new file mode 100644 index 00000000000..00fbc19848c --- /dev/null +++ b/tests/detection/test_map.py @@ -0,0 +1,284 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections import namedtuple + +import pytest +import torch + +from tests.helpers.testers import MetricTester +from torchmetrics.detection.map import MAP +from torchmetrics.utilities.imports import ( + _PYCOCOTOOLS_AVAILABLE, + _TORCHVISION_AVAILABLE, + _TORCHVISION_GREATER_EQUAL_0_8, +) + +Input = namedtuple("Input", ["preds", "target", "num_classes"]) + +_inputs = Input( + preds=[ + [ + dict( + boxes=torch.Tensor([[258.15, 41.29, 606.41, 285.07]]), + scores=torch.Tensor([0.236]), + labels=torch.IntTensor([4]), + ), # coco image id 42 + dict( + boxes=torch.Tensor([[61.00, 22.75, 565.00, 632.42], [12.66, 3.32, 281.26, 275.23]]), + scores=torch.Tensor([0.318, 0.726]), + labels=torch.IntTensor([3, 2]), + ), # coco image id 73 + ], + [ + dict( + boxes=torch.Tensor( + [ + [87.87, 276.25, 384.29, 379.43], + [0.00, 3.66, 142.15, 316.06], + [296.55, 93.96, 314.97, 152.79], + [328.94, 97.05, 342.49, 122.98], + [356.62, 95.47, 372.33, 147.55], + [464.08, 105.09, 495.74, 146.99], + [276.11, 103.84, 291.44, 150.72], + ] + ), + scores=torch.Tensor([0.546, 0.3, 0.407, 0.611, 0.335, 0.805, 0.953]), + labels=torch.IntTensor([4, 1, 0, 0, 0, 0, 0]), + ), # coco image id 74 + dict( + boxes=torch.Tensor([[0.00, 2.87, 601.00, 421.52]]), + scores=torch.Tensor([0.699, 0.423]), + labels=torch.IntTensor([5]), + ), # coco image id 133 + ], + ], + target=[ + [ + dict( + boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), + labels=torch.IntTensor([4]), + ), # coco image id 42 + dict( + boxes=torch.Tensor( + [ + [13.00, 22.75, 548.98, 632.42], + [1.66, 3.32, 270.26, 275.23], + ] + ), + labels=torch.IntTensor([2, 2]), + ), # coco image id 73 + ], + [ + dict( + boxes=torch.Tensor( + [ + [61.87, 276.25, 358.29, 379.43], + [2.75, 3.66, 162.15, 316.06], + [295.55, 93.96, 313.97, 152.79], + [326.94, 97.05, 340.49, 122.98], + [356.62, 95.47, 372.33, 147.55], + [462.08, 105.09, 493.74, 146.99], + [277.11, 103.84, 292.44, 150.72], + ] + ), + labels=torch.IntTensor([4, 1, 0, 0, 0, 0, 0]), + ), # coco image id 74 + dict( + boxes=torch.Tensor([[13.99, 2.87, 640.00, 421.52]]), + labels=torch.IntTensor([5]), + ), # coco image id 133 + ], + ], + num_classes=6, +) + + +def _compare_fn(preds, target) -> dict: + """Comparison function for map implementation. 
+ + Official pycocotools results calculated from a subset of https://github.com/cocodataset/cocoapi/tree/master/results + All classes + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.706 + Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.901 + Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.846 + Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.689 + Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.800 + Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.701 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.592 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.716 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.716 + Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.767 + Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.800 + Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.700 + + Class 0 + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.725 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.780 + + Class 1 + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.800 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.800 + + Class 2 + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.454 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.450 + + Class 3 + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = -1.000 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = -1.000 + + Class 4 + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.650 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.650 + + Class 5 + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.900 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.900 + """ + return { + "map": torch.Tensor([0.706]), + "map_50": torch.Tensor([0.901]), + "map_75": torch.Tensor([0.846]), + "map_small": torch.Tensor([0.689]), + "map_medium": torch.Tensor([0.800]), + "map_large": torch.Tensor([0.701]), + "mar_1": torch.Tensor([0.592]), + "mar_10": torch.Tensor([0.716]), + "mar_100": torch.Tensor([0.716]), + "mar_small": torch.Tensor([0.767]), + "mar_medium": torch.Tensor([0.800]), + "mar_large": torch.Tensor([0.700]), + "map_per_class": torch.Tensor([0.725, 0.800, 0.454, -1.000, 0.650, 0.900]), + "mar_100_per_class": torch.Tensor([0.780, 0.800, 0.450, -1.000, 0.650, 0.900]), + } + + +_pytest_condition = not (_PYCOCOTOOLS_AVAILABLE and _TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8) + + +@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed") +class TestMAP(MetricTester): + """Test the MAP metric for object detection predictions. + + Results are compared to original values from the pycocotools implementation. 
+ A subset of the first 10 fake predictions of the official repo is used: + https://github.com/cocodataset/cocoapi/blob/master/results/instances_val2014_fakebbox100_results.json + """ + + atol = 1e-1 + + @pytest.mark.parametrize("ddp", [False, True]) + def test_map(self, ddp): + """Test modular implementation for correctness.""" + + self.run_class_metric_test( + ddp=ddp, + preds=_inputs.preds, + target=_inputs.target, + metric_class=MAP, + sk_metric=_compare_fn, + dist_sync_on_step=False, + check_batch=False, + metric_args={"class_metrics": True}, + ) + + +# noinspection PyTypeChecker +@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed") +def test_error_on_wrong_init(): + """Test class raises the expected errors.""" + + MAP() # no error + + with pytest.raises(ValueError, match="Expected argument `class_metrics` to be a boolean"): + MAP(class_metrics=0) + + +@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed") +def test_error_on_wrong_input(): + """Test class input validation.""" + + metric = MAP() + + metric.update([], []) # no error + + with pytest.raises(ValueError, match="Expected argument `preds` to be of type List"): + metric.update(torch.Tensor(), []) # type: ignore + + with pytest.raises(ValueError, match="Expected argument `target` to be of type List"): + metric.update([], torch.Tensor()) # type: ignore + + with pytest.raises(ValueError, match="Expected argument `preds` and `target` to have the same length"): + metric.update([dict()], [dict(), dict()]) + + with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `boxes` key"): + metric.update( + [dict(scores=torch.Tensor(), labels=torch.IntTensor)], + [dict(boxes=torch.Tensor(), labels=torch.IntTensor())], + ) + + with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `scores` key"): + metric.update( + [dict(boxes=torch.Tensor(), labels=torch.IntTensor)], + [dict(boxes=torch.Tensor(), labels=torch.IntTensor())], + ) + + with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `labels` key"): + metric.update( + [dict(boxes=torch.Tensor(), scores=torch.IntTensor)], + [dict(boxes=torch.Tensor(), labels=torch.IntTensor())], + ) + + with pytest.raises(ValueError, match="Expected all dicts in `target` to contain the `boxes` key"): + metric.update( + [dict(boxes=torch.Tensor(), scores=torch.IntTensor, labels=torch.IntTensor)], + [dict(labels=torch.IntTensor())], + ) + + with pytest.raises(ValueError, match="Expected all dicts in `target` to contain the `labels` key"): + metric.update( + [dict(boxes=torch.Tensor(), scores=torch.IntTensor, labels=torch.IntTensor)], + [dict(boxes=torch.IntTensor())], + ) + + with pytest.raises(ValueError, match="Expected all boxes in `preds` to be of type torch.Tensor"): + metric.update( + [dict(boxes=[], scores=torch.Tensor(), labels=torch.IntTensor())], + [dict(boxes=torch.Tensor(), labels=torch.IntTensor())], + ) + + with pytest.raises(ValueError, match="Expected all scores in `preds` to be of type torch.Tensor"): + metric.update( + [dict(boxes=torch.Tensor(), scores=[], labels=torch.IntTensor())], + [dict(boxes=torch.Tensor(), labels=torch.IntTensor())], + ) + + with pytest.raises(ValueError, match="Expected all labels in `preds` to be of type torch.Tensor"): + metric.update( + [dict(boxes=torch.Tensor(), scores=torch.Tensor(), labels=[])], + [dict(boxes=torch.Tensor(), labels=torch.IntTensor())], + ) 
+ + with pytest.raises(ValueError, match="Expected all boxes in `target` to be of type torch.Tensor"): + metric.update( + [dict(boxes=torch.Tensor(), scores=torch.Tensor(), labels=torch.IntTensor())], + [dict(boxes=[], labels=torch.IntTensor())], + ) + + with pytest.raises(ValueError, match="Expected all labels in `target` to be of type torch.Tensor"): + metric.update( + [dict(boxes=torch.Tensor(), scores=torch.Tensor(), labels=torch.IntTensor())], + [dict(boxes=torch.Tensor(), labels=[])], + ) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index f0cfbfe86a1..f5560391659 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -15,7 +15,7 @@ import pickle import sys from functools import partial -from typing import Any, Callable, Dict, Optional, Sequence +from typing import Any, Callable, Dict, List, Optional, Sequence, Union import numpy as np import pytest @@ -24,6 +24,7 @@ from torch.multiprocessing import Pool, set_start_method from torchmetrics import Metric +from torchmetrics.detection.map import MAPMetricResults try: set_start_method("spawn") @@ -83,6 +84,9 @@ def _assert_tensor(pl_result: Any, key: Optional[str] = None) -> None: if key is None: raise KeyError("Provide Key for Dict based metric results.") assert isinstance(pl_result[key], Tensor) + elif isinstance(pl_result, MAPMetricResults): + for val_index in [a for a in dir(pl_result) if not a.startswith("__")]: + assert isinstance(pl_result[val_index], Tensor) else: assert isinstance(pl_result, Tensor) @@ -104,8 +108,8 @@ def _assert_requires_grad(metric: Metric, pl_result: Any, key: Optional[str] = N def _class_test( rank: int, worldsize: int, - preds: Tensor, - target: Tensor, + preds: Union[Tensor, List[Dict[str, Tensor]]], + target: Union[Tensor, List[Dict[str, Tensor]]], metric_class: Metric, sk_metric: Callable, dist_sync_on_step: bool, @@ -139,8 +143,8 @@ def _class_test( kwargs_update: Additional keyword arguments that will be passed with preds and target when running update on the metric. 
""" - assert preds.shape[0] == target.shape[0] - num_batches = preds.shape[0] + assert len(preds) == len(target) + num_batches = len(preds) if not metric_args: metric_args = {} @@ -160,8 +164,11 @@ def _class_test( # move to device metric = metric.to(device) - preds = preds.to(device) - target = target.to(device) + + if isinstance(preds, torch.Tensor): + preds = preds.to(device) + target = target.to(device) + kwargs_update = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in kwargs_update.items()} # verify metrics work after being loaded from pickled state @@ -174,33 +181,56 @@ def _class_test( batch_result = metric(preds[i], target[i], **batch_kwargs_update) if metric.dist_sync_on_step and check_dist_sync_on_step and rank == 0: - ddp_preds = torch.cat([preds[i + r] for r in range(worldsize)]).cpu() - ddp_target = torch.cat([target[i + r] for r in range(worldsize)]).cpu() + if isinstance(preds, torch.Tensor): + ddp_preds = torch.cat([preds[i + r] for r in range(worldsize)]).cpu() + ddp_target = torch.cat([target[i + r] for r in range(worldsize)]).cpu() + else: + ddp_preds = [preds[i + r] for r in range(worldsize)] + ddp_target = [target[i + r] for r in range(worldsize)] ddp_kwargs_upd = { k: torch.cat([v[i + r] for r in range(worldsize)]).cpu() if isinstance(v, Tensor) else v for k, v in (kwargs_update if fragment_kwargs else batch_kwargs_update).items() } sk_batch_result = sk_metric(ddp_preds, ddp_target, **ddp_kwargs_upd) - _assert_allclose(batch_result, sk_batch_result, atol=atol) + if isinstance(batch_result, dict): + for key in batch_result: + _assert_allclose(batch_result, sk_batch_result[key].numpy(), atol=atol, key=key) + else: + _assert_allclose(batch_result, sk_batch_result, atol=atol) elif check_batch and not metric.dist_sync_on_step: batch_kwargs_update = { k: v.cpu() if isinstance(v, Tensor) else v for k, v in (batch_kwargs_update if fragment_kwargs else kwargs_update).items() } - sk_batch_result = sk_metric(preds[i].cpu(), target[i].cpu(), **batch_kwargs_update) - _assert_allclose(batch_result, sk_batch_result, atol=atol) + preds_ = preds[i].cpu() if isinstance(preds, torch.Tensor) else preds[i] + target_ = target[i].cpu() if isinstance(target, torch.Tensor) else target[i] + sk_batch_result = sk_metric(preds_, target_, **batch_kwargs_update) + if isinstance(batch_result, dict): + for key in batch_result.keys(): + _assert_allclose(batch_result, sk_batch_result[key].numpy(), atol=atol, key=key) + else: + _assert_allclose(batch_result, sk_batch_result, atol=atol) # check that metrics are hashable assert hash(metric) # check on all batches on all ranks result = metric.compute() - _assert_tensor(result) + if isinstance(result, dict): + for key in result.keys(): + _assert_tensor(result, key=key) + else: + _assert_tensor(result) + + if isinstance(preds, torch.Tensor): + total_preds = torch.cat([preds[i] for i in range(num_batches)]).cpu() + total_target = torch.cat([target[i] for i in range(num_batches)]).cpu() + else: + total_preds = [item for sublist in preds for item in sublist] + total_target = [item for sublist in target for item in sublist] - total_preds = torch.cat([preds[i] for i in range(num_batches)]).cpu() - total_target = torch.cat([target[i] for i in range(num_batches)]).cpu() total_kwargs_update = { k: torch.cat([v[i] for i in range(num_batches)]).cpu() if isinstance(v, Tensor) else v for k, v in kwargs_update.items() @@ -208,7 +238,11 @@ def _class_test( sk_result = sk_metric(total_preds, total_target, **total_kwargs_update) # assert after aggregation - 
_assert_allclose(result, sk_result, atol=atol) + if isinstance(sk_result, dict): + for key in sk_result.keys(): + _assert_allclose(result, sk_result[key].numpy(), atol=atol, key=key) + else: + _assert_allclose(result, sk_result, atol=atol) def _functional_test( @@ -357,8 +391,8 @@ def run_functional_metric_test( def run_class_metric_test( self, ddp: bool, - preds: Tensor, - target: Tensor, + preds: Union[Tensor, List[Dict]], + target: Union[Tensor, List[Dict]], metric_class: Metric, sk_metric: Callable, dist_sync_on_step: bool, diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index d4b397d248e..91639cb61b4 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -40,6 +40,7 @@ StatScores, ) from torchmetrics.collections import MetricCollection # noqa: E402 +from torchmetrics.detection import MAP # noqa: E402 from torchmetrics.image import FID, IS, KID, LPIPS, PSNR, SSIM # noqa: E402 from torchmetrics.metric import Metric # noqa: E402 from torchmetrics.regression import ( # noqa: E402 @@ -97,6 +98,7 @@ "KID", "KLDivergence", "LPIPS", + "MAP", "MatthewsCorrcoef", "MaxMetric", "MeanAbsoluteError", diff --git a/torchmetrics/detection/__init__.py b/torchmetrics/detection/__init__.py new file mode 100644 index 00000000000..f8d01bdb293 --- /dev/null +++ b/torchmetrics/detection/__init__.py @@ -0,0 +1,14 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torchmetrics.detection.map import MAP # noqa: F401 diff --git a/torchmetrics/detection/map.py b/torchmetrics/detection/map.py new file mode 100644 index 00000000000..3e3b9cbb154 --- /dev/null +++ b/torchmetrics/detection/map.py @@ -0,0 +1,409 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging
+import sys
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union
+
+import torch
+from torch import Tensor
+
+from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import (
+    _PYCOCOTOOLS_AVAILABLE,
+    _TORCHVISION_AVAILABLE,
+    _TORCHVISION_GREATER_EQUAL_0_8,
+)
+
+if _TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8:
+    from torchvision.ops import box_convert
+else:
+    box_convert = None
+
+if _PYCOCOTOOLS_AVAILABLE:
+    from pycocotools.coco import COCO
+    from pycocotools.cocoeval import COCOeval
+else:
+    COCO, COCOeval = None, None
+
+log = logging.getLogger(__name__)
+
+
+@dataclass
+class MAPMetricResults:
+    """Dataclass to wrap the final mAP results."""
+
+    map: Tensor
+    map_50: Tensor
+    map_75: Tensor
+    map_small: Tensor
+    map_medium: Tensor
+    map_large: Tensor
+    mar_1: Tensor
+    mar_10: Tensor
+    mar_100: Tensor
+    mar_small: Tensor
+    mar_medium: Tensor
+    mar_large: Tensor
+    map_per_class: Tensor
+    mar_100_per_class: Tensor
+
+    def __getitem__(self, key: str) -> Union[Tensor, List[Tensor]]:
+        return getattr(self, key)
+
+
+# noinspection PyMethodMayBeStatic
+class WriteToLog:
+    """Logging class to move logs to log.debug()."""
+
+    def write(self, buf: str) -> None:  # skipcq: PY-D0003, PYL-R0201
+        for line in buf.rstrip().splitlines():
+            log.debug(line.rstrip())
+
+    def flush(self) -> None:  # skipcq: PY-D0003, PYL-R0201
+        for handler in log.handlers:
+            handler.flush()
+
+    def close(self) -> None:  # skipcq: PY-D0003, PYL-R0201
+        for handler in log.handlers:
+            handler.close()
+
+
+class _hide_prints:
+    """Internal helper context to suppress the default output of the pycocotools package."""
+
+    def __init__(self) -> None:
+        self._original_stdout = None
+
+    def __enter__(self) -> None:
+        self._original_stdout = sys.stdout  # type: ignore
+        sys.stdout = WriteToLog()  # type: ignore
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:  # type: ignore
+        sys.stdout.close()
+        sys.stdout = self._original_stdout  # type: ignore
+
+
+def _input_validator(preds: List[Dict[str, torch.Tensor]], targets: List[Dict[str, torch.Tensor]]) -> None:
+    """Ensure the correct input format of `preds` and `targets`"""
+
+    if not isinstance(preds, Sequence):
+        raise ValueError("Expected argument `preds` to be of type List")
+    if not isinstance(targets, Sequence):
+        raise ValueError("Expected argument `target` to be of type List")
+    if len(preds) != len(targets):
+        raise ValueError("Expected argument `preds` and `target` to have the same length")
+
+    for k in ["boxes", "scores", "labels"]:
+        if any(k not in p for p in preds):
+            raise ValueError(f"Expected all dicts in `preds` to contain the `{k}` key")
+
+    for k in ["boxes", "labels"]:
+        if any(k not in p for p in targets):
+            raise ValueError(f"Expected all dicts in `target` to contain the `{k}` key")
+
+    if any(type(pred["boxes"]) is not torch.Tensor for pred in preds):
+        raise ValueError("Expected all boxes in `preds` to be of type torch.Tensor")
+    if any(type(pred["scores"]) is not torch.Tensor for pred in preds):
+        raise ValueError("Expected all scores in `preds` to be of type torch.Tensor")
+    if any(type(pred["labels"]) is not torch.Tensor for pred in preds):
+        raise ValueError("Expected all labels in `preds` to be of type torch.Tensor")
+    if any(type(target["boxes"]) is not torch.Tensor for target in targets):
+        raise ValueError("Expected all boxes in `target` to be of type torch.Tensor")
+    if any(type(target["labels"]) is not torch.Tensor for target in targets):
+        raise ValueError("Expected all labels in `target` to be of type torch.Tensor")
+
+    for i, item in enumerate(targets):
+        if item["boxes"].size(0) != item["labels"].size(0):
+            raise ValueError(
+                f"Input boxes and labels of sample {i} in targets have a"
+                f" different length (expected {item['boxes'].size(0)} labels, got {item['labels'].size(0)})"
+            )
+    for i, item in enumerate(preds):
+        if item["boxes"].size(0) != item["labels"].size(0) != item["scores"].size(0):
+            raise ValueError(
+                f"Input boxes, labels and scores of sample {i} in preds have a"
+                f" different length (expected {item['boxes'].size(0)} labels and scores,"
+                f" got {item['labels'].size(0)} labels and {item['scores'].size(0)} scores)"
+            )
+
+
+class MAP(Metric):
+    r"""
+    Computes the `Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR)\
+    <https://jonathan-hui.medium.com/map-mean-average-precision-for-object-detection-45c121a31173>`_\
+    for object detection predictions.
+    Optionally, the mAP and mAR values can be calculated per class.
+
+    Predicted boxes and targets have to be in Pascal VOC format
+    (xmin-top left, ymin-top left, xmax-bottom right, ymax-bottom right).
+    See the :meth:`update` method for more information about the input format to this metric.
+
+    .. note::
+        This metric is a wrapper for the
+        `pycocotools `_,
+        which is a standard implementation for the mAP metric for object detection. Using this metric
+        therefore requires you to have `pycocotools` installed. Please install with ``pip install pycocotools`` or
+        ``pip install torchmetrics[detection]``.
+
+    .. note::
+        This metric requires you to have `torchvision` version 0.8.0 or newer installed (with corresponding
+        version 1.7.0 of torch or newer). Please install with ``pip install torchvision`` or
+        ``pip install torchmetrics[detection]``.
+
+    .. note::
+        As the pycocotools library cannot deal with tensors directly, all results have to be transferred
+        to the CPU; this might have a performance impact on your training.
+
+    Args:
+        class_metrics:
+            Option to enable per-class metrics for mAP and mAR_100. Has a performance impact. default: False
+        compute_on_step:
+            Forward only calls ``update()`` and returns ``None`` if this is set to ``False``.
+        dist_sync_on_step:
+            Synchronize metric state across processes at each ``forward()``
+            before returning the value at the step
+        process_group:
+            Specify the process group on which synchronization is called.
+            default: ``None`` (which selects the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When ``None``, DDP
+            will be used to perform the allgather
+
+    Raises:
+        ImportError:
+            If ``pycocotools`` is not installed
+        ImportError:
+            If ``torchvision`` is not installed or version installed is lower than 0.8.0
+        ValueError:
+            If ``class_metrics`` is not a boolean
+    """
+
+    def __init__(
+        self,
+        class_metrics: bool = False,
+        compute_on_step: bool = True,
+        dist_sync_on_step: bool = False,
+        process_group: Optional[Any] = None,
+        dist_sync_fn: Callable = None,
+    ) -> None:  # type: ignore
+        super().__init__(
+            compute_on_step=compute_on_step,
+            dist_sync_on_step=dist_sync_on_step,
+            process_group=process_group,
+            dist_sync_fn=dist_sync_fn,
+        )
+
+        if not _PYCOCOTOOLS_AVAILABLE:
+            raise ImportError(
+                "`MAP` metric requires that `pycocotools` is installed."
+                " Please install with `pip install pycocotools` or `pip install torchmetrics[detection]`"
+            )
+        if not (_TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8):
+            raise ImportError(
+                "`MAP` metric requires that `torchvision` version 0.8.0 or newer is installed."
+                " Please install with `pip install torchvision` or `pip install torchmetrics[detection]`"
+            )
+
+        if not isinstance(class_metrics, bool):
+            raise ValueError("Expected argument `class_metrics` to be a boolean")
+        self.class_metrics = class_metrics
+
+        self.add_state("detection_boxes", default=[], dist_reduce_fx=None)
+        self.add_state("detection_scores", default=[], dist_reduce_fx=None)
+        self.add_state("detection_labels", default=[], dist_reduce_fx=None)
+        self.add_state("groundtruth_boxes", default=[], dist_reduce_fx=None)
+        self.add_state("groundtruth_labels", default=[], dist_reduce_fx=None)
+
+    def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None:  # type: ignore
+        """Add detections and groundtruth to the metric.
+
+        Args:
+            preds: A list consisting of dictionaries each containing the key-values\
+            (each dictionary corresponds to a single image):
+            - ``boxes``: torch.FloatTensor of shape
+                [num_boxes, 4] containing `num_boxes` detection boxes of the format
+                [xmin, ymin, xmax, ymax] in absolute image coordinates.
+            - ``scores``: torch.FloatTensor of shape
+                [num_boxes] containing detection scores for the boxes.
+            - ``labels``: torch.IntTensor of shape
+                [num_boxes] containing 0-indexed detection classes for the boxes.
+
+            target: A list consisting of dictionaries each containing the key-values\
+            (each dictionary corresponds to a single image):
+            - ``boxes``: torch.FloatTensor of shape
+                [num_boxes, 4] containing `num_boxes` groundtruth boxes of the format
+                [xmin, ymin, xmax, ymax] in absolute image coordinates.
+            - ``labels``: torch.IntTensor of shape
+                [num_boxes] containing 0-indexed groundtruth classes for the boxes
+                (using the same label indexing as ``preds``).
+
+        Raises:
+            ValueError:
+                If ``preds`` is not of type List[Dict[str, torch.Tensor]]
+            ValueError:
+                If ``target`` is not of type List[Dict[str, torch.Tensor]]
+            ValueError:
+                If ``preds`` and ``target`` are not of the same length
+            ValueError:
+                If any of ``preds.boxes``, ``preds.scores``
+                and ``preds.labels`` are not of the same length
+            ValueError:
+                If any of ``target.boxes`` and ``target.labels`` are not of the same length
+            ValueError:
+                If any box is not of type float and of length 4
+            ValueError:
+                If any class is not of type int and of length 1
+            ValueError:
+                If any score is not of type float and of length 1
+        """
+        _input_validator(preds, target)
+
+        for item in preds:
+            self.detection_boxes.append(item["boxes"])
+            self.detection_scores.append(item["scores"])
+            self.detection_labels.append(item["labels"])
+
+        for item in target:
+            self.groundtruth_boxes.append(item["boxes"])
+            self.groundtruth_labels.append(item["labels"])
+
+    def compute(self) -> dict:
+        """Compute the `Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR)` scores. All detections added in
+        the `update()` method are included.
+
+        Note:
+            Main `map` score is calculated with @[ IoU=0.50:0.95 | area=all | maxDets=100 ]
+
+        Returns:
+            dict containing
+
+            - map: ``torch.Tensor``
+            - map_50: ``torch.Tensor``
+            - map_75: ``torch.Tensor``
+            - map_small: ``torch.Tensor``
+            - map_medium: ``torch.Tensor``
+            - map_large: ``torch.Tensor``
+            - mar_1: ``torch.Tensor``
+            - mar_10: ``torch.Tensor``
+            - mar_100: ``torch.Tensor``
+            - mar_small: ``torch.Tensor``
+            - mar_medium: ``torch.Tensor``
+            - mar_large: ``torch.Tensor``
+            - map_per_class: ``torch.Tensor`` (-1 if class metrics are disabled)
+            - mar_100_per_class: ``torch.Tensor`` (-1 if class metrics are disabled)
+        """
+        coco_target, coco_preds = COCO(), COCO()
+        coco_target.dataset = self._get_coco_format(self.groundtruth_boxes, self.groundtruth_labels)
+        coco_preds.dataset = self._get_coco_format(self.detection_boxes, self.detection_labels, self.detection_scores)
+
+        with _hide_prints():
+            coco_target.createIndex()
+            coco_preds.createIndex()
+            coco_eval = COCOeval(coco_target, coco_preds, "bbox")
+            coco_eval.evaluate()
+            coco_eval.accumulate()
+            coco_eval.summarize()
+            stats = coco_eval.stats
+
+        map_per_class_values: Tensor = torch.Tensor([-1])
+        mar_100_per_class_values: Tensor = torch.Tensor([-1])
+        # if class mode is enabled, evaluate metrics per class
+        if self.class_metrics:
+            map_per_class_list = []
+            mar_100_per_class_list = []
+            for class_id in torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist():
+                coco_eval.params.catIds = [class_id]
+                with _hide_prints():
+                    coco_eval.evaluate()
+                    coco_eval.accumulate()
+                    coco_eval.summarize()
+                    class_stats = coco_eval.stats
+
+                map_per_class_list.append(torch.Tensor([class_stats[0]]))
+                mar_100_per_class_list.append(torch.Tensor([class_stats[8]]))
+            map_per_class_values = torch.Tensor(map_per_class_list)
+            mar_100_per_class_values = torch.Tensor(mar_100_per_class_list)
+
+        metrics = MAPMetricResults(
+            map=torch.Tensor([stats[0]]),
+            map_50=torch.Tensor([stats[1]]),
+            map_75=torch.Tensor([stats[2]]),
+            map_small=torch.Tensor([stats[3]]),
+            map_medium=torch.Tensor([stats[4]]),
+            map_large=torch.Tensor([stats[5]]),
+            mar_1=torch.Tensor([stats[6]]),
+            mar_10=torch.Tensor([stats[7]]),
+            mar_100=torch.Tensor([stats[8]]),
+            mar_small=torch.Tensor([stats[9]]),
+            mar_medium=torch.Tensor([stats[10]]),
+            mar_large=torch.Tensor([stats[11]]),
+            map_per_class=map_per_class_values,
+            mar_100_per_class=mar_100_per_class_values,
+        )
+        return metrics.__dict__
+
+    def _get_coco_format(
+        self, boxes: List[torch.Tensor], labels: List[torch.Tensor], scores: Optional[List[torch.Tensor]] = None
+    ) -> Dict:
+        """Transforms and returns all cached targets or predictions in COCO format.
+
+        Format is defined at https://cocodataset.org/#format-data
+        """
+
+        images = []
+        annotations = []
+        annotation_id = 1  # has to start with 1, otherwise COCOEval results are wrong
+
+        boxes = [box_convert(box, in_fmt="xyxy", out_fmt="xywh") for box in boxes]
+        for image_id, (image_boxes, image_labels) in enumerate(zip(boxes, labels)):
+            image_boxes = image_boxes.cpu().tolist()
+            image_labels = image_labels.cpu().tolist()
+
+            images.append({"id": image_id})
+            for k, (image_box, image_label) in enumerate(zip(image_boxes, image_labels)):
+                if len(image_box) != 4:
+                    raise ValueError(
+                        f"Invalid input box of sample {image_id}, element {k} (expected 4 values, got {len(image_box)})"
+                    )
+
+                if type(image_label) != int:
+                    raise ValueError(
+                        f"Invalid input class of sample {image_id}, element {k}"
+                        f" (expected value of type integer, got type {type(image_label)})"
+                    )
+
+                annotation = {
+                    "id": annotation_id,
+                    "image_id": image_id,
+                    "bbox": image_box,
+                    "category_id": image_label,
+                    "area": image_box[2] * image_box[3],
+                    "iscrowd": 0,
+                }
+                if scores is not None:
+                    score = scores[image_id][k].cpu().tolist()
+                    if type(score) != float:
+                        raise ValueError(
+                            f"Invalid input score of sample {image_id}, element {k}"
+                            f" (expected value of type float, got type {type(score)})"
+                        )
+                    annotation["score"] = score
+                annotations.append(annotation)
+                annotation_id += 1
+
+        classes = [
+            {"id": i, "name": str(i)}
+            for i in torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist()
+        ]
+        return {"images": images, "annotations": annotations, "categories": classes}
diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py
index 0f61d45860d..59185c42953 100644
--- a/torchmetrics/utilities/imports.py
+++ b/torchmetrics/utilities/imports.py
@@ -80,6 +80,9 @@ def _compare_version(package: str, op: Callable, version: str) -> Optional[bool]
 _SCIPY_AVAILABLE: bool = _module_available("scipy")
 _TORCH_FIDELITY_AVAILABLE: bool = _module_available("torch_fidelity")
 _LPIPS_AVAILABLE: bool = _module_available("lpips")
+_PYCOCOTOOLS_AVAILABLE: bool = _module_available("pycocotools")
+_TORCHVISION_AVAILABLE: bool = _module_available("torchvision")
+_TORCHVISION_GREATER_EQUAL_0_8: Optional[bool] = _compare_version("torchvision", operator.ge, "0.8.0")
 _TQDM_AVAILABLE: bool = _module_available("tqdm")
 _TRANSFORMERS_AVAILABLE: bool = _module_available("transformers")
 _PESQ_AVAILABLE: bool = _module_available("pesq")
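
For reviewers: a minimal usage sketch of the new `MAP` metric, put together from the docstrings and test fixtures in this diff. The box coordinates, scores and labels below are illustrative values only, and running it assumes `pycocotools` and `torchvision >= 0.8.0` are installed (e.g. via `pip install torchmetrics[detection]`).

```python
import torch
from torchmetrics import MAP

# One dict per image, following the input format documented in MAP.update():
# boxes in Pascal VOC (xmin, ymin, xmax, ymax) absolute coordinates.
preds = [
    dict(
        boxes=torch.Tensor([[258.0, 41.0, 606.0, 285.0]]),
        scores=torch.Tensor([0.536]),
        labels=torch.IntTensor([0]),
    )
]
target = [
    dict(
        boxes=torch.Tensor([[214.0, 41.0, 562.0, 285.0]]),
        labels=torch.IntTensor([0]),
    )
]

metric = MAP(class_metrics=False)  # class_metrics=True additionally reports per-class mAP/mAR
metric.update(preds, target)
result = metric.compute()  # dict with "map", "map_50", "map_75", "mar_1", "mar_100", ...
print(result["map"], result["mar_100"])
```

`compute()` returns the fields of `MAPMetricResults` as a plain dict, so individual scores can be read out by key as shown above.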