Add mean average precision metric for object detection #467

Merged: 196 commits, Oct 27, 2021
Changes from 156 commits
4f9e73f
add first, very rough draft
tkupek Jul 20, 2021
4f4eb23
some basic cleanup and refactoring
Aug 19, 2021
76c5dac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2021
a826ecb
Merge branch 'master' into mean-average-precision
tkupek Aug 19, 2021
5c769cc
move requirements
SkafteNicki Aug 19, 2021
01fb237
docs
SkafteNicki Aug 19, 2021
8c09b4a
changelog
SkafteNicki Aug 19, 2021
452bcfb
conditional import
SkafteNicki Aug 19, 2021
be4b753
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2021
5174c20
Update torchmetrics/image/map.py
tkupek Aug 19, 2021
c09807d
Update torchmetrics/image/map.py
tkupek Aug 19, 2021
8398a8b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2021
10cffbd
Update torchmetrics/image/map.py
tkupek Aug 19, 2021
9a239db
Update torchmetrics/image/map.py
tkupek Aug 19, 2021
5d776d3
Update torchmetrics/image/map.py
tkupek Aug 19, 2021
040a444
Update torchmetrics/image/map.py
tkupek Aug 19, 2021
ac69201
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2021
4d6ebcc
Update torchmetrics/image/map.py
tkupek Aug 19, 2021
4debcc0
add missing dict import
Aug 19, 2021
8f6c6a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2021
a866434
add typing for preds and targets
tkupek Aug 20, 2021
555a88d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 20, 2021
51f2a3a
fix unittests for new input types
Aug 20, 2021
e0db8a6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 20, 2021
b43ab8b
Merge branch 'master' into mean-average-precision
Borda Aug 24, 2021
d7ceecb
Update torchmetrics/image/map.py
tkupek Aug 25, 2021
53e6a26
Merge branch 'master' into mean-average-precision
Borda Aug 26, 2021
a54d352
Merge branch 'master' into mean-average-precision
SkafteNicki Aug 26, 2021
5ce132e
add init
SkafteNicki Aug 26, 2021
c2e6cea
fix docstring
SkafteNicki Aug 26, 2021
d6ad53a
input validation
SkafteNicki Aug 26, 2021
e1edd2c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2021
e25cc03
dict
Borda Aug 27, 2021
16dd543
Merge branch 'master' into mean-average-precision
Borda Aug 27, 2021
c6f720b
make private
SkafteNicki Aug 27, 2021
9cb144e
typing
SkafteNicki Aug 27, 2021
d720960
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2021
0a4b445
fix typing
SkafteNicki Aug 27, 2021
50440ca
Merge branch 'mean-average-precision' of https://github.com/tkupek/py…
SkafteNicki Aug 27, 2021
f701ce4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2021
2146045
Merge branch 'master' into mean-average-precision
Borda Aug 27, 2021
56d4d8c
Merge branch 'master' into mean-average-precision
tkupek Sep 3, 2021
3466b24
Merge branch 'master' into mean-average-precision
Borda Sep 3, 2021
7837af2
minor validation fixes
Sep 4, 2021
892bc59
typing of single values in list
Sep 4, 2021
7158af0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 4, 2021
10e98c2
solve full dataset computation issue by caching predictions
Sep 5, 2021
ef965c7
attempt to fix issue with ddp caching
Sep 5, 2021
96030b9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2021
ac2b245
formatting
Sep 5, 2021
ff15700
fix deepsource issues
Sep 5, 2021
91f3cac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2021
953d38b
fix typos
Sep 5, 2021
d6a1bbd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2021
6bd3ce1
fix mypy typing
tkupek Sep 5, 2021
0d6b0af
Merge branch 'master' into mean-average-precision
Borda Sep 6, 2021
841cbf1
rewrite mAP test to be aligned with pycocotools results
Sep 7, 2021
04b47c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2021
da7d794
add missing docstrings and fix mypy
Sep 7, 2021
bd4a566
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2021
7d824e6
Merge branch 'master' into mean-average-precision
tkupek Sep 7, 2021
04e310c
fix import
Sep 7, 2021
380127c
fix docstring and unittest
Sep 7, 2021
8dff22f
add additional tests for input checking
Sep 7, 2021
33ac1e0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2021
1a5ddd0
fix mypy type checking
Sep 7, 2021
e142355
fix mypy typing
Sep 7, 2021
42b7877
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2021
85577da
fix deepsource issue
Sep 7, 2021
22fabd3
fix mypy
tkupek Sep 7, 2021
d2867c6
update dict keys, fix update docstring
tkupek Sep 12, 2021
9375367
fix map docstring on boxes format
tkupek Sep 12, 2021
92ddf83
Merge branch 'master' into mean-average-precision
tkupek Sep 12, 2021
e64a53e
fix update dict keys
Sep 16, 2021
645ece1
Merge branch 'master' into mean-average-precision
tkupek Sep 16, 2021
ef6a15b
fix mypy
Sep 16, 2021
607e726
roughly fix unittests to check ddp
Sep 20, 2021
a3e82d9
refactor unittests for map
Sep 20, 2021
3a66228
add TODO
Sep 20, 2021
0d46f7a
fix atol
Sep 20, 2021
1b43847
remove TODO
Sep 20, 2021
a18d55b
fix unittests for MAP metric
tkupek Sep 20, 2021
888926a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 20, 2021
86a0419
fix ddp_sync
Sep 20, 2021
df40597
some adjustments to metric class tester
Sep 20, 2021
7a8e2a0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 20, 2021
1c71993
fix dist_reduce_fx
Sep 20, 2021
4ee3c6f
fix box labels size comparison
Sep 20, 2021
6b31713
fix deepsource antipattern
Sep 21, 2021
6f2e804
Merge branch 'master' into mean-average-precision
tkupek Sep 21, 2021
bbbcabb
Merge branch 'master' into mean-average-precision
mergify[bot] Sep 21, 2021
75efc05
add ValueError for not _PYCOCOTOOLS_AVAILABLE
tkupek Sep 21, 2021
ddbf8c2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 21, 2021
a179aad
fix docstring
Sep 21, 2021
2e83e71
set compute_on_step default to False
Sep 21, 2021
1bec2fe
Merge branch 'master' into mean-average-precision
mergify[bot] Sep 21, 2021
0b6b6ee
adjusting PT 1.8.2
Borda Sep 21, 2021
450dfad
Merge branch 'master' into mean-average-precision
mergify[bot] Sep 21, 2021
5e9d6ea
Merge branch 'master' into mean-average-precision
mergify[bot] Sep 21, 2021
357cdce
Merge branch 'master' into mean-average-precision
mergify[bot] Sep 21, 2021
a694e83
Merge branch 'master' into mean-average-precision
Borda Sep 24, 2021
02efa4c
Merge branch 'master' into mean-average-precision
mergify[bot] Sep 25, 2021
1711e9c
fix class_mar_value variable naming
Sep 26, 2021
e88a931
move input validation of box, label and score length to input_validator
Sep 26, 2021
9888bdc
refactor _get_coco_format to get rid of _is_pred
Sep 26, 2021
047399d
move box, scores, labels validation
Sep 26, 2021
ecc5dcb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2021
4248a7d
Revert "move box, scores, labels validation"
Sep 26, 2021
d437190
re-implement box, scores, labels validation
Sep 26, 2021
c0cde8b
move cocoeval outputs to debug logger
Sep 26, 2021
0efafd7
add type hints for update method
Sep 26, 2021
af4b2d2
rename map_value to map and others
Sep 26, 2021
1f9718b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2021
cc4a1e6
add mypy exception for tests
Sep 26, 2021
dbffb19
change validation from list to sequence
Sep 26, 2021
b32ba25
simplify init
Sep 26, 2021
cd321e8
fix deepsource for WriteToLog
Sep 26, 2021
ae70cb4
remove unnecessary states
Sep 27, 2021
6bd529d
add more values to MAPMetricResults, extend tests
Sep 27, 2021
8d6cc16
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2021
57500a2
adjust docstring
Sep 27, 2021
2f19710
fix typing
Sep 27, 2021
6b68571
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2021
0ba7b55
fix typing
Sep 27, 2021
377101e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2021
75096a6
fix deepsource python
Sep 27, 2021
52e88c8
move map to new detection package
Sep 27, 2021
ff9865d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2021
52d9b9e
refactor num_classes parameter
Sep 27, 2021
36d4215
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2021
51e37cc
rename val_id to val_index
Sep 27, 2021
b3be8bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2021
b2d72ab
fix deepsource
Sep 27, 2021
b2f971d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2021
5ccf0c6
fix docs
Sep 27, 2021
fe83767
fix deepsource
Sep 27, 2021
a7ea040
Merge branch 'master' into mean-average-precision
mergify[bot] Sep 27, 2021
80b0696
add missing cpu transfer
Sep 27, 2021
c19a6f5
Merge branch 'master' into mean-average-precision
mergify[bot] Sep 27, 2021
335e645
add missing cpu transfer
Sep 27, 2021
aff3500
Merge branch 'master' into mean-average-precision
mergify[bot] Sep 30, 2021
20f0410
Merge branch 'master' into mean-average-precision
mergify[bot] Sep 30, 2021
32bfcd3
Merge branch 'master' into mean-average-precision
mergify[bot] Oct 5, 2021
02770bc
Merge branch 'master' into mean-average-precision
tkupek Oct 14, 2021
0841e52
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 14, 2021
8f96617
change map_per_class and mar_100_per_class to tensor
Oct 14, 2021
f7bbe94
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 14, 2021
5865773
change map metric return type to dict
Oct 14, 2021
defa383
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 14, 2021
0c56619
deepsource fixes
Oct 14, 2021
a949148
change result dict keys
Oct 14, 2021
57a5bd8
Merge branch 'master' into mean-average-precision
mergify[bot] Oct 14, 2021
1c0f791
Merge branch 'master' into mean-average-precision
mergify[bot] Oct 14, 2021
310fc58
Update torchmetrics/detection/map.py
tkupek Oct 15, 2021
bbe125c
Update torchmetrics/detection/map.py
tkupek Oct 15, 2021
b1abaa3
some fixes for generated docs
Oct 15, 2021
4df1989
Update torchmetrics/detection/map.py
tkupek Oct 15, 2021
1ee3daf
update docstring
Oct 15, 2021
e9d6c81
some more docstring fixes
Oct 18, 2021
f458f02
fix class_metrics argument and return type
tkupek Oct 18, 2021
2d9140f
fix class metrics default return value
Oct 19, 2021
d448f1e
Merge branch 'master' into mean-average-precision
mergify[bot] Oct 20, 2021
0b0e612
fix box format conversion to coco
Oct 21, 2021
ba384cf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2021
54f3605
fix torchvision import
Oct 21, 2021
82e4ce2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2021
87372b3
Merge branch 'master' into mean-average-precision
mergify[bot] Oct 22, 2021
a160a15
Merge branch 'master' into mean-average-precision
mergify[bot] Oct 22, 2021
78fd562
Merge branch 'master' into mean-average-precision
mergify[bot] Oct 22, 2021
bda42d1
Merge branch 'master' into mean-average-precision
mergify[bot] Oct 23, 2021
f8a905a
fix tests
SkafteNicki Oct 25, 2021
8b2e3f7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2021
bf49109
fix condition
SkafteNicki Oct 25, 2021
42e55e1
fix testing
SkafteNicki Oct 25, 2021
289087e
fix
SkafteNicki Oct 25, 2021
97d15b6
Merge branch 'master' into mean-average-precision
mergify[bot] Oct 25, 2021
15f31e2
prepare 0.6 RC
Borda Oct 25, 2021
0618e9e
Merge branch 'master' into mean-average-precision
SkafteNicki Oct 25, 2021
61ff33c
Merge branch 'master' into mean-average-precision
SkafteNicki Oct 25, 2021
e3df5d6
Merge branch 'master' into mean-average-precision
mergify[bot] Oct 25, 2021
bdba836
update
SkafteNicki Oct 26, 2021
b7b844e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 26, 2021
4289727
add sample 133 to have equal number of test samples
Oct 27, 2021
4194eb0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 27, 2021
01a2c7e
Merge branch 'master' into mean-average-precision
mergify[bot] Oct 27, 2021
3bc39d5
Merge branch 'master' into mean-average-precision
mergify[bot] Oct 27, 2021
f80b8af
format
Borda Oct 27, 2021
7bf9cb7
typing
Borda Oct 27, 2021
4d2645d
ImportError
Borda Oct 27, 2021
c3a4559
is instance
Borda Oct 27, 2021
1a12ce6
fix test
SkafteNicki Oct 27, 2021
fc66734
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 27, 2021
3cbe630
Merge branch 'master' into mean-average-precision
Borda Oct 27, 2021
cf01eeb
req. numpy
Borda Oct 27, 2021
6542e40
CI fix
Borda Oct 27, 2021
13ec1ee
test fix
Borda Oct 27, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Learned Perceptual Image Patch Similarity (LPIPS) ([#431](https://github.com/PyTorchLightning/metrics/issues/431))


- Added `MAP` (mean average precision) metric to image package ([#467](https://github.com/PyTorchLightning/metrics/pull/467))


- Added Tweedie Deviance Score ([#499](https://github.com/PyTorchLightning/metrics/pull/499))


1 change: 1 addition & 0 deletions docs/source/links.rst
@@ -37,6 +37,7 @@
.. _MAPE implementation returns: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_percentage_error.html
.. _mean squared logarithmic error: https://scikit-learn.org/stable/modules/model_evaluation.html#mean-squared-log-error
.. _LPIPS: https://arxiv.org/abs/1801.03924
.. _Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR): https://jonathan-hui.medium.com/map-mean-average-precision-for-object-detection-45c121a31173
.. _Tweedie Deviance Score: https://en.wikipedia.org/wiki/Tweedie_distribution#The_Tweedie_deviance
.. _Permutation Invariant Training of Deep Models: https://ieeexplore.ieee.org/document/7952154
.. _Computes the Top-label Calibration Error: https://arxiv.org/pdf/1909.10155.pdf
12 changes: 12 additions & 0 deletions docs/source/references/modules.rst
@@ -396,6 +396,18 @@ SSIM
.. autoclass:: torchmetrics.SSIM
:noindex:

*****************
Detection Metrics
*****************

Object detection metrics evaluate predicted detections against ground-truth detections on images.

MAP
~~~

.. autoclass:: torchmetrics.MAP
:noindex:
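The mAP/mAR values reported by this metric are built on the intersection-over-union (IoU) between predicted and ground-truth boxes, evaluated at the IoU thresholds shown in the test expectations below (0.50, 0.75, and the 0.50:0.95 range). As a rough, dependency-free illustration of that underlying quantity only (the actual metric delegates the full computation to pycocotools), IoU for boxes in `[xmin, ymin, xmax, ymax]` format can be sketched as:

```python
def box_iou(a, b):
    """Intersection-over-union of two boxes in [xmin, ymin, xmax, ymax] format."""
    # Intersection rectangle: overlap of the two boxes (may be empty).
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    # Union = sum of areas minus the double-counted intersection.
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0
```

A prediction typically counts as a true positive at a given threshold when its IoU with a matched ground-truth box of the same class meets that threshold; averaging the resulting precision over thresholds and classes yields mAP.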

******************
Regression Metrics
******************
1 change: 1 addition & 0 deletions requirements/detection.txt
@@ -0,0 +1 @@
pycocotools>=2.0.2
1 change: 1 addition & 0 deletions requirements/test.txt
@@ -19,6 +19,7 @@ scikit-image>0.17.1
# add extra requirements
-r image.txt
-r text.txt
-r detection.txt
-r audio.txt

# audio
7 changes: 5 additions & 2 deletions setup.py
@@ -26,8 +26,11 @@ def _load_py_module(fname, pkg="torchmetrics"):

def _prepare_extras():
extras = {
"image": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="image.txt"),
"text": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="text.txt"),
"image": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="image.txt"), # skipcq: PYL-W0212
"text": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="text.txt"), # skipcq: PYL-W0212
"detection": setup_tools._load_requirements( # skipcq: PYL-W0212
path_dir=_PATH_REQUIRE, file_name="detection.txt"
),
"audio": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="audio.txt"),
}
# create an 'all' keyword that installs all possible dependencies
Empty file added tests/detection/__init__.py
Empty file.
280 changes: 280 additions & 0 deletions tests/detection/test_map.py
@@ -0,0 +1,280 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import namedtuple

import pytest
import torch

from tests.helpers.testers import MetricTester
from torchmetrics.detection.map import MAP
from torchmetrics.utilities.imports import _PYCOCOTOOLS_AVAILABLE

Input = namedtuple("Input", ["preds", "target", "num_classes"])

_inputs = Input(
preds=[
{
"boxes": torch.Tensor([[258.15, 41.29, 348.26, 243.78]]),
"scores": torch.Tensor([0.236]),
"labels": torch.IntTensor([4]),
}, # coco image id 42
{
"boxes": torch.Tensor([[61, 22.75, 504, 609.67], [12.66, 3.32, 268.6, 271.91]]),
"scores": torch.Tensor([0.318, 0.726]),
"labels": torch.IntTensor([3, 2]),
}, # coco image id 73
{
"boxes": torch.Tensor(
[
[87.87, 276.25, 296.42, 103.18],
[0, 3.66, 142.15, 312.4],
[296.55, 93.96, 18.42, 58.83],
[328.94, 97.05, 13.55, 25.93],
[356.62, 95.47, 15.71, 52.08],
[464.08, 105.09, 31.66, 41.9],
[276.11, 103.84, 15.33, 46.88],
]
),
"scores": torch.Tensor([0.546, 0.3, 0.407, 0.611, 0.335, 0.805, 0.953]),
"labels": torch.IntTensor([4, 1, 0, 0, 0, 0, 0]),
}, # coco image id 74
],
target=[
{
"boxes": torch.Tensor([[214.15, 41.29, 348.26, 243.78]]),
"labels": torch.IntTensor([4]),
}, # coco image id 42
{
"boxes": torch.Tensor(
[
[13.0, 22.75, 535.98, 609.67],
[1.66, 3.32, 268.6, 271.91],
]
),
"labels": torch.IntTensor([2, 2]),
}, # coco image id 73
{
"boxes": torch.Tensor(
[
[61.87, 276.25, 296.42, 103.18],
[2.75, 3.66, 159.4, 312.4],
[295.55, 93.96, 18.42, 58.83],
[326.94, 97.05, 13.55, 25.93],
[356.62, 95.47, 15.71, 52.08],
[462.08, 105.09, 31.66, 41.9],
[277.11, 103.84, 15.33, 46.88],
]
),
"labels": torch.IntTensor([4, 1, 0, 0, 0, 0, 0]),
}, # coco image id 74
],
num_classes=5,
)
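The commit history above mentions fixing the box format conversion to COCO: pycocotools expects boxes as `[x, y, width, height]`, while detection pipelines commonly emit `[xmin, ymin, xmax, ymax]`. A minimal sketch of that conversion as a standalone helper (hypothetical name, not the PR's actual code):

```python
def xyxy_to_xywh(box):
    """Convert a box from [xmin, ymin, xmax, ymax] to COCO-style [x, y, w, h]."""
    xmin, ymin, xmax, ymax = box
    return [xmin, ymin, xmax - xmin, ymax - ymin]
```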


def _compare_fn(preds, target) -> dict:
"""Comparison function for map implementation.

Official pycocotools results calculated from a subset of https://github.com/cocodataset/cocoapi/tree/master/results
All classes
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.658
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.876
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.807
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.689
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.800
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.635
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.515
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.670
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.670
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.767
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.800
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.633

Class 0
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.725
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.780

Class 1
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.800
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.800

Class 2
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.454
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.450

Class 3
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = -1.000

Class 4
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.650
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.650
"""
return {
"map": torch.Tensor([0.658]),
"map_50": torch.Tensor([0.876]),
"map_75": torch.Tensor([0.807]),
"map_small": torch.Tensor([0.689]),
"map_medium": torch.Tensor([0.800]),
"map_large": torch.Tensor([0.635]),
"mar_1": torch.Tensor([0.515]),
"mar_10": torch.Tensor([0.670]),
"mar_100": torch.Tensor([0.670]),
"mar_small": torch.Tensor([0.767]),
"mar_medium": torch.Tensor([0.800]),
"mar_large": torch.Tensor([0.633]),
"map_per_class": torch.Tensor([0.725, 0.800, 0.454, -1.000, 0.650]),
"mar_100_per_class": torch.Tensor([0.780, 0.800, 0.450, -1.000, 0.650]),
}


@pytest.mark.skipif(not _PYCOCOTOOLS_AVAILABLE, reason="test requires that pycocotools is installed")
class TestMAP(MetricTester):
"""Test the MAP metric for object detection predictions.

Results are compared to original values from the pycocotools implementation.
A subset of the first 10 fake predictions of the official repo is used:
https://github.com/cocodataset/cocoapi/blob/master/results/instances_val2014_fakebbox100_results.json
"""

atol = 1e-3

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_map(self, ddp, dist_sync_on_step):
"""Test modular implementation for correctness."""

self.run_class_metric_test(
ddp=ddp,
preds=_inputs.preds,
target=_inputs.target,
metric_class=MAP,
sk_metric=_compare_fn,
dist_sync_on_step=dist_sync_on_step,
metric_args={"class_metrics": True},
)


# noinspection PyTypeChecker
@pytest.mark.skipif(not _PYCOCOTOOLS_AVAILABLE, reason="test requires that pycocotools is installed")
def test_error_on_wrong_init():
"""Test class raises the expected errors."""

MAP() # no error

with pytest.raises(ValueError, match="Expected argument `class_metrics` to be a boolean"):
MAP(class_metrics=0)


@pytest.mark.skipif(not _PYCOCOTOOLS_AVAILABLE, reason="test requires that pycocotools is installed")
def test_error_on_wrong_input():
"""Test class input validation."""

metric = MAP()

metric.update([], []) # no error

with pytest.raises(ValueError, match="Expected argument `preds` to be of type List"):
metric.update(torch.Tensor(), []) # type: ignore

with pytest.raises(ValueError, match="Expected argument `target` to be of type List"):
metric.update([], torch.Tensor()) # type: ignore

with pytest.raises(ValueError, match="Expected argument `preds` and `target` to have the same length"):
metric.update([{}], [{}, {}])

with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `boxes` key"):
metric.update(
[{"scores": torch.Tensor(), "labels": torch.IntTensor}],
[{"boxes": torch.Tensor(), "labels": torch.IntTensor()}],
)

with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `scores` key"):
metric.update(
[{"boxes": torch.Tensor(), "labels": torch.IntTensor}],
[{"boxes": torch.Tensor(), "labels": torch.IntTensor()}],
)

with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `labels` key"):
metric.update(
[{"boxes": torch.Tensor(), "scores": torch.IntTensor}],
[{"boxes": torch.Tensor(), "labels": torch.IntTensor()}],
)

with pytest.raises(ValueError, match="Expected all dicts in `target` to contain the `boxes` key"):
metric.update(
[
{
"boxes": torch.Tensor(),
"scores": torch.IntTensor,
"labels": torch.IntTensor,
}
],
[{"labels": torch.IntTensor()}],
)

with pytest.raises(ValueError, match="Expected all dicts in `target` to contain the `labels` key"):
metric.update(
[
{
"boxes": torch.Tensor(),
"scores": torch.IntTensor,
"labels": torch.IntTensor,
}
],
[{"boxes": torch.IntTensor()}],
)

with pytest.raises(ValueError, match="Expected all boxes in `preds` to be of type torch.Tensor"):
metric.update(
[{"boxes": [], "scores": torch.Tensor(), "labels": torch.IntTensor()}],
[{"boxes": torch.Tensor(), "labels": torch.IntTensor()}],
)

with pytest.raises(ValueError, match="Expected all scores in `preds` to be of type torch.Tensor"):
metric.update(
[{"boxes": torch.Tensor(), "scores": [], "labels": torch.IntTensor()}],
[{"boxes": torch.Tensor(), "labels": torch.IntTensor()}],
)

with pytest.raises(ValueError, match="Expected all labels in `preds` to be of type torch.Tensor"):
metric.update(
[{"boxes": torch.Tensor(), "scores": torch.Tensor(), "labels": []}],
[{"boxes": torch.Tensor(), "labels": torch.IntTensor()}],
)

with pytest.raises(ValueError, match="Expected all boxes in `target` to be of type torch.Tensor"):
metric.update(
[
{
"boxes": torch.Tensor(),
"scores": torch.Tensor(),
"labels": torch.IntTensor(),
}
],
[{"boxes": [], "labels": torch.IntTensor()}],
)

with pytest.raises(ValueError, match="Expected all labels in `target` to be of type torch.Tensor"):
metric.update(
[
{
"boxes": torch.Tensor(),
"scores": torch.Tensor(),
"labels": torch.IntTensor(),
}
],
[{"boxes": torch.Tensor(), "labels": []}],
)
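The validation behaviour exercised by `test_error_on_wrong_input` can be summarised as a simplified, dependency-free sketch (not the metric's actual internal validator): both arguments must be lists of equal length, every prediction dict must carry `boxes`, `scores`, and `labels`, and every target dict must carry `boxes` and `labels`.

```python
def validate_inputs(preds, target):
    """Simplified input checks mirroring those exercised by the tests above."""
    if not isinstance(preds, list):
        raise ValueError("Expected argument `preds` to be of type List")
    if not isinstance(target, list):
        raise ValueError("Expected argument `target` to be of type List")
    if len(preds) != len(target):
        raise ValueError("Expected argument `preds` and `target` to have the same length")
    for key in ("boxes", "scores", "labels"):
        if any(key not in p for p in preds):
            raise ValueError(f"Expected all dicts in `preds` to contain the `{key}` key")
    for key in ("boxes", "labels"):
        if any(key not in t for t in target):
            raise ValueError(f"Expected all dicts in `target` to contain the `{key}` key")
```

The real metric additionally checks that each value is a `torch.Tensor` and that box, score, and label counts agree per sample.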