diff --git a/anomalib/utils/metrics/__init__.py b/anomalib/utils/metrics/__init__.py
index 57937d11ba..c0c44d4e4d 100644
--- a/anomalib/utils/metrics/__init__.py
+++ b/anomalib/utils/metrics/__init__.py
@@ -18,8 +18,9 @@
 from .collection import AnomalibMetricCollection
 from .min_max import MinMax
 from .optimal_f1 import OptimalF1
+from .pro import PRO
 
-__all__ = ["AUROC", "AUPR", "AUPRO", "OptimalF1", "AdaptiveThreshold", "AnomalyScoreDistribution", "MinMax"]
+__all__ = ["AUROC", "AUPR", "AUPRO", "OptimalF1", "AdaptiveThreshold", "AnomalyScoreDistribution", "MinMax", "PRO"]
 
 
 def get_metrics(config: Union[ListConfig, DictConfig]) -> Tuple[AnomalibMetricCollection, AnomalibMetricCollection]:
diff --git a/anomalib/utils/metrics/aupro.py b/anomalib/utils/metrics/aupro.py
index f973efe0eb..58f69d4e3d 100644
--- a/anomalib/utils/metrics/aupro.py
+++ b/anomalib/utils/metrics/aupro.py
@@ -6,13 +6,17 @@
 from typing import Any, Callable, List, Optional, Tuple
 
 import torch
-from kornia.contrib import connected_components
 from matplotlib.figure import Figure
 from torch import Tensor
 from torchmetrics import Metric
 from torchmetrics.functional import auc, roc
 from torchmetrics.utilities.data import dim_zero_cat
 
+from anomalib.utils.metrics.pro import (
+    connected_components_cpu,
+    connected_components_gpu,
+)
+
 from .plotting_utils import plot_figure
 
 
@@ -80,9 +84,10 @@ def _compute(self) -> Tuple[Tensor, Tensor]:
             )
         target = target.unsqueeze(1)  # kornia expects N1HW format
         target = target.type(torch.float)  # kornia expects FloatTensor
-        cca = connected_components(
-            target, num_iterations=1000
-        )  # Need higher thresholds this to avoid oversegmentation.
+        if target.is_cuda:
+            cca = connected_components_gpu(target)
+        else:
+            cca = connected_components_cpu(target)
 
         preds = preds.flatten()
         cca = cca.flatten()
diff --git a/anomalib/utils/metrics/pro.py b/anomalib/utils/metrics/pro.py
new file mode 100644
index 0000000000..6cb4ff96da
--- /dev/null
+++ b/anomalib/utils/metrics/pro.py
@@ -0,0 +1,113 @@
+"""Implementation of PRO metric based on TorchMetrics."""
+
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import List
+
+import cv2
+import numpy as np
+import torch
+from kornia.contrib import connected_components
+from torch import Tensor
+from torchmetrics import Metric
+from torchmetrics.functional import recall
+from torchmetrics.utilities.data import dim_zero_cat
+
+
+class PRO(Metric):
+    """Per-Region Overlap (PRO) Score."""
+
+    target: List[Tensor]
+    preds: List[Tensor]
+
+    def __init__(self, threshold: float = 0.5, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.threshold = threshold
+
+        self.add_state("preds", default=[], dist_reduce_fx="cat")
+        self.add_state("target", default=[], dist_reduce_fx="cat")
+
+    def update(self, predictions: Tensor, targets: Tensor) -> None:
+        """Store the predictions and targets of the current batch."""
+
+        self.target.append(targets)
+        self.preds.append(predictions)
+
+    def compute(self) -> Tensor:
+        """Compute the macro average of the PRO score across all regions in all batches."""
+        target = dim_zero_cat(self.target)
+        preds = dim_zero_cat(self.preds)
+
+        if target.is_cuda:
+            comps = connected_components_gpu(target.unsqueeze(1))
+        else:
+            comps = connected_components_cpu(target.unsqueeze(1))
+        pro = pro_score(preds, comps, threshold=self.threshold)
+        return pro
+
+
+def pro_score(predictions: Tensor, comps: Tensor, threshold: float = 0.5) -> Tensor:
+    """Calculate the PRO score for a batch of predictions.
+
+    Args:
+        predictions (Tensor): Predicted anomaly masks (Bx1xHxW)
+        comps (Tensor): Labeled connected components (Bx1xHxW). The components should be labeled from 0 to N
+        threshold (float): When predictions are passed as float, the threshold is used to binarize the predictions.
+
+    Returns:
+        Tensor: Scalar value representing the average PRO score for the input batch.
+    """
+    if predictions.is_floating_point():
+        predictions = predictions > threshold
+
+    n_comps = len(comps.unique())
+    if n_comps == 1:  # only background
+        # return on the inputs' device so that CPU and GPU runs yield tensors on matching devices
+        return torch.tensor([1.0], device=comps.device)
+
+    preds = comps.clone()
+    preds[~predictions] = 0
+    pro = recall(preds.flatten(), comps.flatten(), num_classes=n_comps, average="macro", ignore_index=0)
+    return pro
+
+
+def connected_components_gpu(binary_input: Tensor, num_iterations: int = 1000) -> Tensor:
+    """Perform connected component labeling on GPU and remap the labels from 0 to N.
+
+    Args:
+        binary_input (Tensor): Binary input data from which we want to extract connected components (Bx1xHxW)
+        num_iterations (int): Number of iterations used in the connected component computation.
+
+    Returns:
+        Tensor: Components labeled from 0 to N.
+    """
+    components = connected_components(binary_input, num_iterations=num_iterations)
+
+    # remap component values from 0 to N
+    labels = components.unique()
+    for new_label, old_label in enumerate(labels):
+        components[components == old_label] = new_label
+
+    return components.int()
+
+
+def connected_components_cpu(image: Tensor) -> Tensor:
+    """Connected component labeling on CPU.
+
+    Args:
+        image (Tensor): Binary input data from which we want to extract connected components (Bx1xHxW)
+
+    Returns:
+        Tensor: Components labeled from 0 to N.
+    """
+    components = torch.zeros_like(image)
+    label_idx = 1
+    for i, mask in enumerate(image):
+        mask = mask.squeeze().numpy().astype(np.uint8)
+        _, comps = cv2.connectedComponents(mask)
+        # remap component values to make sure every component has a unique value when outputs are concatenated
+        for label in np.unique(comps)[1:]:
+            components[i, 0, ...][np.where(comps == label)] = label_idx
+            label_idx += 1
+    return components.int()
diff --git a/tests/pre_merge/utils/metrics/test_aupro.py b/tests/pre_merge/utils/metrics/test_aupro.py
index e4bcfcad66..88466d0566 100644
--- a/tests/pre_merge/utils/metrics/test_aupro.py
+++ b/tests/pre_merge/utils/metrics/test_aupro.py
@@ -13,21 +13,17 @@ def pytest_generate_tests(metafunc):
             torch.tensor(
                 [
                     [
-                        [
-                            [0, 0, 0, 1, 0, 0, 0],
-                        ]
-                        * 400,
+                        [0, 0, 0, 1, 0, 0, 0],
                     ]
+                    * 400,
                 ]
             ),
             torch.tensor(
                 [
                     [
-                        [
-                            [0, 1, 0, 1, 0, 1, 0],
-                        ]
-                        * 400,
+                        [0, 1, 0, 1, 0, 1, 0],
                     ]
+                    * 400,
                 ]
             ),
         ]
diff --git a/tests/pre_merge/utils/metrics/test_pro.py b/tests/pre_merge/utils/metrics/test_pro.py
new file mode 100644
index 0000000000..25fa07a76c
--- /dev/null
+++ b/tests/pre_merge/utils/metrics/test_pro.py
@@ -0,0 +1,80 @@
+import torch
+from torch import Tensor
+from torchvision.transforms import RandomAffine
+
+from anomalib.data.utils import random_2d_perlin
+from anomalib.utils.metrics.pro import (
+    PRO,
+    connected_components_cpu,
+    connected_components_gpu,
+)
+
+
+def test_pro():
+    """Checks if PRO metric computes the (macro) average of the per-region overlap."""
+
+    labels = Tensor(
+        [
+            [
+                [0, 0, 0, 0, 0],
+                [1, 0, 0, 0, 0],
+                [0, 0, 0, 0, 0],
+                [1, 1, 0, 0, 0],
+                [0, 0, 0, 0, 0],
+                [1, 1, 1, 0, 0],
+                [0, 0, 0, 0, 0],
+                [1, 1, 1, 1, 0],
+                [0, 0, 0, 0, 0],
+                [1, 1, 1, 1, 1],
+            ]
+        ]
+    )
+
+    preds = (torch.arange(10) / 10) + 0.05
+    preds = preds.unsqueeze(1).repeat(1, 5).view(1, 1, 10, 5)
+
+    thresholds = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
+    targets = [1.0, 0.8, 0.6, 0.4, 0.2, 0.0]
+    for threshold, target in zip(thresholds, targets):
+        pro = PRO(threshold=threshold)
+        pro.update(preds, labels)
+        assert pro.compute() == target
+
+
+def test_device_consistency():
+    """Test if the pro metric yields the same results between cpu and gpu."""
+
+    transform = RandomAffine(5, None, (0.95, 1.05), 5)
+
+    batch = torch.zeros((32, 256, 256))
+    for i in range(batch.shape[0]):
+        batch[i, ...] = random_2d_perlin((256, 256), (torch.tensor(4), torch.tensor(4))) > 0.5
+
+    preds = transform(batch).unsqueeze(1)
+
+    pro_cpu = PRO()
+    pro_gpu = PRO()
+
+    pro_cpu.update(preds.cpu(), batch.cpu())
+    pro_gpu.update(preds.cuda(), batch.cuda())
+
+    assert torch.isclose(pro_cpu.compute(), pro_gpu.compute().cpu())
+
+
+def test_connected_component_labeling():
+    """Tests if the connected component labeling algorithms on cpu and gpu yield the same result."""
+
+    # generate batch of random binary images using perlin noise
+    batch = torch.zeros((32, 1, 256, 256))
+    for i in range(batch.shape[0]):
+        batch[i, ...] = random_2d_perlin((256, 256), (torch.tensor(4), torch.tensor(4))) > 0.5
+
+    # get connected component results on both cpu and gpu
+    cc_cpu = connected_components_cpu(batch.cpu())
+    cc_gpu = connected_components_gpu(batch.cuda())
+
+    # check if comps are ordered from 0 to N
+    assert len(cc_cpu.unique()) == cc_cpu.unique().max() + 1
+    assert len(cc_gpu.unique()) == cc_gpu.unique().max() + 1
+    # check if same number of comps found between cpu and gpu
+    assert len(cc_cpu.unique()) == len(cc_gpu.unique())