Skip to content

Commit

Permalink
Fix bugs in MAP related to iou_type="segm" (#1763)
Browse files Browse the repository at this point in the history
* fix + tests

* changelog

* refactor test

* fix

* fix

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored May 10, 2023
1 parent cf7604d commit 41bb1fa
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed states being saved in metrics that use `register_buffer` ([#1728](https://github.com/Lightning-AI/torchmetrics/pull/1728))


- Fixed states not being correctly synced and device transferred in `MeanAveragePrecision` for `iou_type="segm"` ([#1763](https://github.com/Lightning-AI/torchmetrics/pull/1763))

## [0.11.4] - 2023-03-10

### Fixed
Expand Down
45 changes: 44 additions & 1 deletion src/torchmetrics/detection/mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
from torch import IntTensor, Tensor

from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator
Expand Down Expand Up @@ -870,6 +871,48 @@ def compute(self) -> dict:
metrics.classes = torch.tensor(classes, dtype=torch.int)
return metrics

def _apply(self, fn: Callable) -> torch.nn.Module:
    """Apply ``fn`` to the metric, leaving tuple-valued states alone for ``segm``.

    When ``iou_type`` is ``"segm"`` the ``detections`` and ``groundtruths`` states are
    tuples rather than tensors, so they are passed to the parent ``_apply`` as
    ``exclude_state`` and are therefore left out of the casting.
    """
    if self.iou_type != "segm":
        return super()._apply(fn)
    # Tuple states are excluded from the casting done by the parent implementation.
    return super()._apply(fn, exclude_state=("detections", "groundtruths"))

def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Optional[Any] = None) -> None:
    """Synchronise states across processes, with extra handling for ``segm``.

    The parent synchronisation always runs first. For ``iou_type="segm"`` the
    ``detections`` and ``groundtruths`` states are lists of tuples instead of tensors,
    so they are additionally gathered with ``_gather_tuple_list`` and stored back on
    the metric.
    """
    super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group)

    if self.iou_type != "segm":
        return

    gather = self._gather_tuple_list
    self.detections = gather(self.detections, process_group)
    self.groundtruths = gather(self.groundtruths, process_group)

@staticmethod
def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any] = None) -> List[Any]:
    """Gather a list of tuples from every process and interleave them into one list.

    Args:
        list_to_gather: the local list of tuples held by the calling process.
        process_group: process group forwarded to the ``torch.distributed`` calls.

    Returns:
        A flat list ordered by element index first and rank second, i.e.
        ``[rank0[0], rank1[0], ..., rank0[1], rank1[1], ...]``.

    Raises:
        ValueError: If any rank holds a different number of elements than rank 0.
    """
    world_size = dist.get_world_size(group=process_group)
    list_gathered = [None] * world_size
    dist.all_gather_object(list_gathered, list_to_gather, group=process_group)

    # Ranks are expected to hold *different* values, so only the lengths are compared:
    # the merge below indexes every rank's list with rank 0's length.
    num_elements = len(list_gathered[0])
    for rank in range(1, world_size):
        if len(list_gathered[rank]) != num_elements:
            raise ValueError(
                f"Rank {rank} and Rank 0 have a different number of elements in the list to gather:"
                f" {len(list_gathered[rank])} != {num_elements}."
            )

    return [list_gathered[rank][idx] for idx in range(num_elements) for rank in range(world_size)]

def plot(
self, val: Optional[Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
Expand Down
10 changes: 9 additions & 1 deletion src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,11 +703,16 @@ def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "Metric":
out._dtype_convert = False
return out

def _apply(self, fn: Callable) -> Module:
def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> Module:
"""Overwrite _apply function such that we can also move metric states to the correct device.
This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, `.float`, `.half` etc. methods
are called. Dtype conversion is guarded and will only happen through the special `set_dtype` method.
Args:
fn: the function to apply
exclude_state: list of state variables to exclude from applying the function, that then needs to be handled
by the metric class itself.
"""
this = super()._apply(fn)
fs = str(fn)
Expand All @@ -717,6 +722,9 @@ def _apply(self, fn: Callable) -> Module:

# Also apply fn to metric states and defaults
for key, value in this._defaults.items():
if key in exclude_state:
continue

if isinstance(value, Tensor):
this._defaults[key] = fn(value)
elif isinstance(value, Sequence):
Expand Down
58 changes: 57 additions & 1 deletion tests/unittests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,18 @@ def _create_inputs_masks() -> Input:
"labels": IntTensor([3, 2]),
}, # 73
],
[
{
"masks": _mask_unsqueeze_bool(inputs_json["preds"][0]),
"scores": Tensor([0.236]),
"labels": IntTensor([4]),
},
{
"masks": _masks_stack_bool([inputs_json["preds"][1], inputs_json["preds"][2]]),
"scores": Tensor([0.318, 0.726]),
"labels": IntTensor([3, 2]),
}, # 73
],
],
target=[
[
Expand All @@ -59,6 +71,13 @@ def _create_inputs_masks() -> Input:
"labels": IntTensor([2, 2]),
}, # 73
],
[
{"masks": _mask_unsqueeze_bool(inputs_json["targets"][0]), "labels": IntTensor([4])}, # 42
{
"masks": _masks_stack_bool([inputs_json["targets"][1], inputs_json["targets"][2]]),
"labels": IntTensor([2, 2]),
}, # 73
],
],
)

Expand Down Expand Up @@ -357,7 +376,7 @@ def test_map_bbox(self, compute_on_cpu, ddp):
metric_args={"class_metrics": True, "compute_on_cpu": compute_on_cpu},
)

@pytest.mark.parametrize("ddp", [False])
@pytest.mark.parametrize("ddp", [False, True])
def test_map_segm(self, compute_on_cpu, ddp):
"""Test modular implementation for correctness."""
_inputs_masks = _create_inputs_masks()
Expand Down Expand Up @@ -660,3 +679,40 @@ def test_error_on_wrong_input():
[{"boxes": Tensor(), "scores": Tensor(), "labels": IntTensor()}],
[{"boxes": Tensor(), "labels": []}],
)


def _generate_random_segm_input(device):
"""Generate random inputs for mAP when iou_type=segm."""
preds = []
targets = []
for _ in range(2):
result = {}
num_preds = torch.randint(0, 10, (1,)).item()
result["scores"] = torch.rand((num_preds,), device=device)
result["labels"] = torch.randint(0, 10, (num_preds,), device=device)
result["masks"] = torch.randint(0, 2, (num_preds, 10, 10), device=device).bool()
preds.append(result)
gt = {}
num_gt = torch.randint(0, 10, (1,)).item()
gt["labels"] = torch.randint(0, 10, (num_gt,), device=device)
gt["masks"] = torch.randint(0, 2, (num_gt, 10, 10), device=device).bool()
targets.append(gt)
return preds, targets


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_device_changing():
    """Check that a ``segm`` metric updated on cuda can be moved to cpu and computed.

    Regression test for https://github.com/Lightning-AI/torchmetrics/issues/1743; it
    exercises the custom apply function of the metric.
    """
    device = "cuda"
    metric = MeanAveragePrecision(iou_type="segm").to(device)

    num_updates = 2
    for _ in range(num_updates):
        batch_preds, batch_targets = _generate_random_segm_input(device)
        metric.update(batch_preds, batch_targets)

    # Move the metric back to cpu before computing the final value.
    result = metric.cpu().compute()
    assert isinstance(result, dict)

0 comments on commit 41bb1fa

Please sign in to comment.