diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 332571d345..1eb8935141 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -69,6 +69,13 @@ Metrics .. autoclass:: SurfaceDistanceMetric :members: +`Surface dice` +-------------- +.. autofunction:: compute_surface_dice + +.. autoclass:: SurfaceDiceMetric + :members: + `Mean squared error` -------------------- .. autoclass:: MSEMetric diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index d18c20f7b2..53d11893ed 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -17,5 +17,6 @@ from .metric import Cumulative, CumulativeIterationMetric, IterationMetric, Metric from .regression import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric from .rocauc import ROCAUCMetric, compute_roc_auc +from .surface_dice import SurfaceDiceMetric, compute_surface_dice from .surface_distance import SurfaceDistanceMetric, compute_average_surface_distance from .utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py new file mode 100644 index 0000000000..5630af178d --- /dev/null +++ b/monai/metrics/surface_dice.py @@ -0,0 +1,236 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List, Union + +import numpy as np +import torch + +from monai.metrics.utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background +from monai.utils import MetricReduction, convert_data_type + +from .metric import CumulativeIterationMetric + + +class SurfaceDiceMetric(CumulativeIterationMetric): + """ + Computes the Normalized Surface Distance (NSD) for each batch sample and class of + predicted segmentations `y_pred` and corresponding reference segmentations `y` according to equation :eq:`nsd`. + This implementation supports 2D images. For 3D images, please refer to DeepMind's implementation + https://github.com/deepmind/surface-distance. + + The class- and batch sample-wise NSD values can be aggregated with the function `aggregate`. + + Args: + class_thresholds: List of class-specific thresholds. + The thresholds relate to the acceptable amount of deviation in the segmentation boundary in pixels. + Each threshold needs to be a finite, non-negative number. + include_background: Whether to skip NSD computation on the first channel of the predicted output. + Defaults to ``False``. + distance_metric: The metric used to compute surface distances. + One of [``"euclidean"``, ``"chessboard"``, ``"taxicab"``]. + Defaults to ``"euclidean"``. + reduction: The mode to aggregate metrics. + One of [``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean_channel"``, ``"sum_channel"``, + ``"none"``]. + Defaults to ``"mean"``. + If ``"none"`` is chosen, no aggregation will be performed. + The aggregation will ignore nan values. + get_not_nans: whether to return the `not_nans` count. + Defaults to ``False``. + `not_nans` is the number of batch samples for which not all class-specific NSD values were nan values. + If set to ``True``, the function `aggregate` will return both the aggregated NSD and the `not_nans` count. + If set to ``False``, `aggregate` will only return the aggregated NSD. + """ + + def __init__( + self, + class_thresholds: List[float], + include_background: bool = False, + distance_metric: str = "euclidean", + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + get_not_nans: bool = False, + ) -> None: + super().__init__() + self.class_thresholds = class_thresholds + self.include_background = include_background + self.distance_metric = distance_metric + self.reduction = reduction + self.get_not_nans = get_not_nans + + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore + r""" + Args: + y_pred: Predicted segmentation, typically segmentation model output. + It must be a one-hot encoded, batch-first tensor [B,C,H,W]. + y: Reference segmentation. + It must be a one-hot encoded, batch-first tensor [B,C,H,W]. + + Returns: + Pytorch Tensor of shape [B,C], containing the NSD values :math:`\operatorname {NSD}_{b,c}` for each batch + index :math:`b` and class :math:`c`. + """ + return compute_surface_dice( + y_pred=y_pred, + y=y, + class_thresholds=self.class_thresholds, + include_background=self.include_background, + distance_metric=self.distance_metric, + ) + + def aggregate(self): + r""" + Aggregates the output of `_compute_tensor`. + + Returns: + If `get_not_nans` is set to ``True``, this function returns the aggregated NSD and the `not_nans` count. + If `get_not_nans` is set to ``False``, this function returns only the aggregated NSD. + """ + data = self.get_buffer() + if not isinstance(data, torch.Tensor): + raise ValueError("the data to aggregate must be PyTorch Tensor.") + + # do metric reduction + f, not_nans = do_metric_reduction(data, self.reduction) + return (f, not_nans) if self.get_not_nans else f + + +def compute_surface_dice( + y_pred: torch.Tensor, + y: torch.Tensor, + class_thresholds: List[float], + include_background: bool = False, + distance_metric: str = "euclidean", +): + r""" + This function computes the (Normalized) Surface Dice (NSD) between the two tensors `y_pred` (referred to as + :math:`\hat{Y}`) and `y` (referred to as :math:`Y`). This metric determines which fraction of a segmentation + boundary is correctly predicted. A boundary element is considered correctly predicted if the closest distance to the + reference boundary is smaller than or equal to the specified threshold related to the acceptable amount of deviation in + pixels. The NSD is bounded between 0 and 1. + + This implementation supports multi-class tasks with an individual threshold :math:`\tau_c` for each class :math:`c`. + The class-specific NSD for batch index :math:`b`, :math:`\operatorname {NSD}_{b,c}`, is computed using the function: + + .. math:: + \operatorname {NSD}_{b,c} \left(Y_{b,c}, \hat{Y}_{b,c}\right) = \frac{\left|\mathcal{D}_{Y_{b,c}}^{'}\right| + + \left| \mathcal{D}_{\hat{Y}_{b,c}}^{'} \right|}{\left|\mathcal{D}_{Y_{b,c}}\right| + + \left|\mathcal{D}_{\hat{Y}_{b,c}}\right|} + :label: nsd + + with :math:`\mathcal{D}_{Y_{b,c}}` and :math:`\mathcal{D}_{\hat{Y}_{b,c}}` being two sets of nearest-neighbor + distances. :math:`\mathcal{D}_{Y_{b,c}}` is computed from the predicted segmentation boundary towards the reference segmentation + boundary and vice-versa for :math:`\mathcal{D}_{\hat{Y}_{b,c}}`. :math:`\mathcal{D}_{Y_{b,c}}^{'}` and + :math:`\mathcal{D}_{\hat{Y}_{b,c}}^{'}` refer to the subsets of distances that are smaller or equal to the + acceptable distance :math:`\tau_c`: + + .. math:: + \mathcal{D}_{Y_{b,c}}^{'} = \{ d \in \mathcal{D}_{Y_{b,c}} \, | \, d \leq \tau_c \}. + + + In the case of a class neither being present in the predicted segmentation, nor in the reference segmentation, a nan value + will be returned for this class. In the case of a class being present in only one of predicted segmentation or + reference segmentation, the class NSD will be 0. + + This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D images. + Be aware that the computation of boundaries is different from DeepMind's implementation + https://github.com/deepmind/surface-distance. In this implementation, the length of a segmentation boundary is + interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary + depends on the local neighborhood (cf. https://arxiv.org/abs/1809.04430). + + Args: + y_pred: Predicted segmentation, typically segmentation model output. + It must be a one-hot encoded, batch-first tensor [B,C,H,W]. + y: Reference segmentation. + It must be a one-hot encoded, batch-first tensor [B,C,H,W]. + class_thresholds: List of class-specific thresholds. + The thresholds relate to the acceptable amount of deviation in the segmentation boundary in pixels. + Each threshold needs to be a finite, non-negative number. + include_background: Whether to skip the surface dice computation on the first channel of + the predicted output. Defaults to ``False``. + distance_metric: The metric used to compute surface distances. + One of [``"euclidean"``, ``"chessboard"``, ``"taxicab"``]. + Defaults to ``"euclidean"``. + + Raises: + ValueError: If `y_pred` and/or `y` are not PyTorch tensors. + ValueError: If `y_pred` and/or `y` do not have four dimensions. + ValueError: If `y_pred` and/or `y` have different shapes. + ValueError: If `y_pred` and/or `y` are not one-hot encoded + ValueError: If the number of channels of `y_pred` and/or `y` is different from the number of class thresholds. + ValueError: If any class threshold is not finite. + ValueError: If any class threshold is negative. + + Returns: + Pytorch Tensor of shape [B,C], containing the NSD values :math:`\operatorname {NSD}_{b,c}` for each batch index + :math:`b` and class :math:`c`. + """ + + if not include_background: + y_pred, y = ignore_background(y_pred=y_pred, y=y) + + if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor): + raise ValueError("y_pred and y must be PyTorch Tensor.") + + if y_pred.ndimension() != 4 or y.ndimension() != 4: + raise ValueError("y_pred and y should have four dimensions: [B,C,H,W].") + + if y_pred.shape != y.shape: + raise ValueError( + f"y_pred and y should have same shape, but instead, shapes are {y_pred.shape} (y_pred) and {y.shape} (y)." + ) + + if not torch.all(y_pred.byte() == y_pred) or not torch.all(y.byte() == y): + raise ValueError("y_pred and y should be binarized tensors (e.g. torch.int64).") + if torch.any(y_pred > 1) or torch.any(y > 1): + raise ValueError("y_pred and y should be one-hot encoded.") + + y = y.float() + y_pred = y_pred.float() + + batch_size, n_class = y_pred.shape[:2] + + if n_class != len(class_thresholds): + raise ValueError( + f"number of classes ({n_class}) does not match number of class thresholds ({len(class_thresholds)})." + ) + + if any(~np.isfinite(class_thresholds)): + raise ValueError("All class thresholds need to be finite.") + + if any(np.array(class_thresholds) < 0): + raise ValueError("All class thresholds need to be >= 0.") + + nsd = np.empty((batch_size, n_class)) + + for b, c in np.ndindex(batch_size, n_class): + (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c], crop=False) + if not np.any(edges_gt): + warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.") + if not np.any(edges_pred): + warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.") + + distances_pred_gt = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric) + distances_gt_pred = get_surface_distance(edges_gt, edges_pred, distance_metric=distance_metric) + + boundary_complete = len(distances_pred_gt) + len(distances_gt_pred) + boundary_correct = np.sum(distances_pred_gt <= class_thresholds[c]) + np.sum( + distances_gt_pred <= class_thresholds[c] + ) + + if boundary_complete == 0: + # the class is neither present in the prediction, nor in the reference segmentation + nsd[b, c] = np.nan + else: + nsd[b, c] = boundary_correct / boundary_complete + + return convert_data_type(nsd, torch.Tensor)[0] diff --git a/tests/min_tests.py b/tests/min_tests.py index 988a703e5a..66b6c9ff3d 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -145,6 +145,7 @@ def run_testsuit(): "test_spacingd", "test_splitdimd", "test_surface_distance", + "test_surface_dice", "test_testtimeaugmentation", "test_torchvision", "test_torchvisiond", diff --git a/tests/test_surface_dice.py b/tests/test_surface_dice.py new file mode 100644 index 0000000000..5252adafce --- /dev/null +++ b/tests/test_surface_dice.py @@ -0,0 +1,292 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +import torch.nn.functional as F + +from monai.metrics.surface_dice import SurfaceDiceMetric + + +class TestAllSurfaceDiceMetrics(unittest.TestCase): + def test_tolerance_euclidean_distance(self): + batch_size = 2 + n_class = 2 + predictions = torch.zeros((batch_size, 480, 640), dtype=torch.int64) + labels = torch.zeros((batch_size, 480, 640), dtype=torch.int64) + predictions[0, :, 50:] = 1 + labels[0, :, 60:] = 1 # 10 px shift + predictions_hot = F.one_hot(predictions, num_classes=n_class).permute(0, 3, 1, 2) + labels_hot = F.one_hot(labels, num_classes=n_class).permute(0, 3, 1, 2) + + sd0 = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True) + res0 = sd0(predictions_hot, labels_hot) + agg0 = sd0.aggregate() # aggregation: nanmean across image then nanmean across batch + sd0_nans = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True, get_not_nans=True) + res0_nans = sd0_nans(predictions_hot, labels_hot) + agg0_nans, not_nans = sd0_nans.aggregate() + + np.testing.assert_array_equal(res0, res0_nans) + np.testing.assert_array_equal(agg0, agg0_nans) + + res1 = SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, labels_hot) + res9 = SurfaceDiceMetric(class_thresholds=[9, 9], include_background=True)(predictions_hot, labels_hot) + res10 = SurfaceDiceMetric(class_thresholds=[10, 10], include_background=True)(predictions_hot, labels_hot) + res11 = SurfaceDiceMetric(class_thresholds=[11, 11], include_background=True)(predictions_hot, labels_hot) + + for res in [res0, res9, res10, res11]: + assert res.shape == torch.Size([2, 2]) + + assert res0[0, 0] < res1[0, 0] < res9[0, 0] < res10[0, 0] + assert res0[0, 1] < res1[0, 1] < res9[0, 1] < res10[0, 1] + np.testing.assert_array_equal(res10, res11) + + expected_res0 = np.zeros((batch_size, n_class)) + expected_res0[0, 1] = 1 - (478 + 480 + 9 * 2) / (480 * 4 + 588 * 2 + 578 * 2) + expected_res0[0, 0] = 1 - (478 + 480 + 9 * 2) / (480 * 4 + 48 * 2 + 58 * 2) + expected_res0[1, 0] = 1 + expected_res0[1, 1] = np.nan + for b, c in np.ndindex(batch_size, n_class): + np.testing.assert_allclose(expected_res0[b, c], res0[b, c]) + np.testing.assert_array_equal(agg0, np.nanmean(np.nanmean(expected_res0, axis=1), axis=0)) + np.testing.assert_equal(not_nans, torch.tensor(2)) + + def test_tolerance_all_distances(self): + batch_size = 1 + n_class = 2 + predictions = torch.zeros((batch_size, 10, 10), dtype=torch.int64) + labels = torch.zeros((batch_size, 10, 10), dtype=torch.int64) + predictions[0, 1:4, 1] = 1 + """ + [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]] + """ + labels[0, 5:8, 6] = 1 + """ + [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]] + """ + predictions_hot = F.one_hot(predictions, num_classes=n_class).permute(0, 3, 1, 2) + labels_hot = F.one_hot(labels, num_classes=n_class).permute(0, 3, 1, 2) + + # Euclidean distance: + # background: + # 36 boundary pixels have 0 distances; non-zero distances: + # distances gt_pred: [3, np.sqrt(9+4), 2, 3, 2, 2, 2, 1] + # distances pred_gt: [1, 2, 2, 1] + # class 1: + # distances gt_pred: [sqrt(25+4), sqrt(25+9), sqrt(25+16)] = [5.38516481, 5.83095189, 6.40312424] + # distances pred_gt: [sqrt(25+16), sqrt(25+9), sqrt(25+4)] = [6.40312424, 5.83095189, 5.38516481] + + res = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True)(predictions_hot, labels_hot) + expected_res = [[1 - (8 + 4) / (36 * 2 + 8 + 4), 0]] + np.testing.assert_array_almost_equal(res, expected_res) + + res = SurfaceDiceMetric(class_thresholds=[2.8, 5.5], include_background=True)(predictions_hot, labels_hot) + expected_res = [[1 - 3 / (36 * 2 + 8 + 4), 1 - (2 + 2) / (3 + 3)]] + np.testing.assert_array_almost_equal(res, expected_res) + + res = SurfaceDiceMetric(class_thresholds=[3, 6], include_background=True)(predictions_hot, labels_hot) + expected_res = [[1 - 1 / (36 * 2 + 8 + 4), 1 - 2 / (3 + 3)]] + np.testing.assert_array_almost_equal(res, expected_res) + + # Chessboard distance: + # background: + # 36 boundary pixels have 0 distances; non-zero distances: + # distances gt_pred: [max(3,0), max(3,2), max(2,0), max(3,3), max(2,0), max(0,2), max(2,0), max(0,1)] = + # [3, 3, 2, 3, 2, 2, 2, 1] + # distances pred_gt: [max(0,1), max(2,0), max(2,0), max(1,0)] = [1, 2, 2, 1] + # class 1: + # distances gt_pred: [max(5,2), max(5,3), max(5,4)] = [5, 5, 5] + # distances pred_gt: [max(5,4), max(5,3), max(5,2)] = [5, 5, 5] + + res = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True, distance_metric="chessboard")( + predictions_hot, labels_hot + ) + expected_res = [[1 - (8 + 4) / (36 * 2 + 8 + 4), 0]] + np.testing.assert_array_almost_equal(res, expected_res) + + res = SurfaceDiceMetric(class_thresholds=[1, 4.999], include_background=True, distance_metric="chessboard")( + predictions_hot, labels_hot + ) + expected_res = [[1 - (7 + 2) / (36 * 2 + 8 + 4), 0]] + np.testing.assert_array_almost_equal(res, expected_res) + + res = SurfaceDiceMetric(class_thresholds=[2, 5], include_background=True, distance_metric="chessboard")( + predictions_hot, labels_hot + ) + expected_res = [[1 - 3 / (36 * 2 + 8 + 4), 1]] + np.testing.assert_array_almost_equal(res, expected_res) + + # Taxicab distance (= Manhattan distance): + # background: + # 36 boundary pixels have 0 distances; non-zero distances: + # distances gt_pred: [3+0, 4+0, 2+0, 0+3, 2+0, 0+2, 2+0, 0+1] = [3, 4, 2, 3, 2, 2, 2, 1] + # distances pred_gt: [0+1, 2+0, 2+0, 1+0] = [1, 2, 2, 1] + # class 1: + # distances gt_pred: [5+2, 5+3, 5+4] = [7, 8, 9] + # distances pred_gt: [5+4, 5+3, 5+2] = [9, 8, 7] + + res = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True, distance_metric="taxicab")( + predictions_hot, labels_hot + ) + expected_res = [[1 - (8 + 4) / (36 * 2 + 8 + 4), 0]] + np.testing.assert_array_almost_equal(res, expected_res) + + res = SurfaceDiceMetric(class_thresholds=[1, 7], include_background=True, distance_metric="taxicab")( + predictions_hot, labels_hot + ) + expected_res = [[1 - (7 + 2) / (36 * 2 + 8 + 4), 1 - (2 + 2) / (3 + 3)]] + np.testing.assert_array_almost_equal(res, expected_res) + + res = SurfaceDiceMetric(class_thresholds=[3, 9], include_background=True, distance_metric="taxicab")( + predictions_hot, labels_hot + ) + expected_res = [[1 - 1 / (36 * 2 + 8 + 4), 1]] + np.testing.assert_array_almost_equal(res, expected_res) + + def test_asserts(self): + batch_size = 1 + n_class = 2 + predictions = torch.zeros((batch_size, 80, 80), dtype=torch.int64) + labels = torch.zeros((batch_size, 80, 80), dtype=torch.int64) + predictions[0, 10:20, 10:20] = 1 + labels[0, 20:30, 20:30] = 1 + predictions_hot = F.one_hot(predictions, num_classes=n_class).permute(0, 3, 1, 2) + labels_hot = F.one_hot(labels, num_classes=n_class).permute(0, 3, 1, 2) + + # no torch tensor + with self.assertRaises(ValueError) as context: + SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot.numpy(), labels_hot) + self.assertEqual( + "y_pred or y must be a list/tuple of `channel-first` Tensors or a `batch-first` Tensor.", + str(context.exception), + ) + with self.assertRaises(ValueError) as context: + SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, labels_hot.numpy()) + self.assertEqual("y_pred and y must be PyTorch Tensor.", str(context.exception)) + + # wrong dimensions + with self.assertRaises(ValueError) as context: + SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions, labels_hot) + self.assertEqual("y_pred and y should have four dimensions: [B,C,H,W].", str(context.exception)) + with self.assertRaises(ValueError) as context: + SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, labels) + self.assertEqual("y_pred and y should have four dimensions: [B,C,H,W].", str(context.exception)) + + # mismatch of shape of input tensors + input_bad_shape = torch.clone(predictions_hot) + input_bad_shape = input_bad_shape[:, :, :, :50] + + with self.assertRaises(ValueError) as context: + SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, input_bad_shape) + self.assertEqual( + "y_pred and y should have same shape, but instead, shapes are torch.Size([1, 2, 80, 80]) (y_pred) and " + "torch.Size([1, 2, 80, 50]) (y).", + str(context.exception), + ) + + # input tensors not one-hot encoded + predictions_no_hot = torch.clone(predictions_hot) + predictions_no_hot[0, :, 0, 0] = torch.tensor([2, 0]) + + with self.assertRaises(ValueError) as context: + SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_no_hot, predictions_hot) + self.assertEqual("y_pred and y should be one-hot encoded.", str(context.exception)) + with self.assertRaises(ValueError) as context: + SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, predictions_no_hot) + self.assertEqual("y_pred and y should be one-hot encoded.", str(context.exception)) + + predictions_no_hot = predictions_no_hot.float() + predictions_no_hot[0, :, 0, 0] = torch.tensor([0.5, 0]) + with self.assertRaises(ValueError) as context: + SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_no_hot, predictions_hot) + self.assertEqual("y_pred and y should be binarized tensors (e.g. torch.int64).", str(context.exception)) + with self.assertRaises(ValueError) as context: + SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, predictions_no_hot) + self.assertEqual("y_pred and y should be binarized tensors (e.g. torch.int64).", str(context.exception)) + + # wrong number of class thresholds + with self.assertRaises(ValueError) as context: + SurfaceDiceMetric(class_thresholds=[1, 1, 1], include_background=True)(predictions_hot, labels_hot) + self.assertEqual("number of classes (2) does not match number of class thresholds (3).", str(context.exception)) + + # inf and nan values in class thresholds + with self.assertRaises(ValueError) as context: + SurfaceDiceMetric(class_thresholds=[np.inf, 1], include_background=True)(predictions_hot, labels_hot) + self.assertEqual("All class thresholds need to be finite.", str(context.exception)) + + with self.assertRaises(ValueError) as context: + SurfaceDiceMetric(class_thresholds=[np.nan, 1], include_background=True)(predictions_hot, labels_hot) + self.assertEqual("All class thresholds need to be finite.", str(context.exception)) + + # negative values in class thresholds: + with self.assertRaises(ValueError) as context: + SurfaceDiceMetric(class_thresholds=[-0.22, 1], include_background=True)(predictions_hot, labels_hot) + self.assertEqual("All class thresholds need to be >= 0.", str(context.exception)) + + def test_not_predicted_not_present(self): + # class is present in labels, but not in prediction -> nsd of 0 should be yielded for that class; class is + # neither present on labels, nor prediction -> nan should be yielded + batch_size = 1 + n_class = 4 + predictions = torch.zeros((batch_size, 80, 80), dtype=torch.int64) + labels = torch.zeros((batch_size, 80, 80), dtype=torch.int64) + predictions[0, 10:20, 10:20] = 1 + labels[0, 10:20, 10:20] = 2 + predictions_hot = F.one_hot(predictions, num_classes=n_class).permute(0, 3, 1, 2) + labels_hot = F.one_hot(labels, num_classes=n_class).permute(0, 3, 1, 2) + + # with and without background class + sur_metric_bgr = SurfaceDiceMetric(class_thresholds=[1, 1, 1, 1], include_background=True) + sur_metric = SurfaceDiceMetric(class_thresholds=[1, 1, 1], include_background=False) + + # test per-class results + res_bgr_classes = sur_metric_bgr(predictions_hot, labels_hot) + np.testing.assert_array_equal(res_bgr_classes, [[1, 0, 0, np.nan]]) + res_classes = sur_metric(predictions_hot, labels_hot) + np.testing.assert_array_equal(res_classes, [[0, 0, np.nan]]) + + # test aggregation + res_bgr = sur_metric_bgr.aggregate() + np.testing.assert_equal(res_bgr, torch.tensor([1 / 3], dtype=torch.float64)) + res = sur_metric.aggregate() + np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float64)) + + predictions_empty = torch.zeros((2, 3, 1, 1)) + sur_metric_nans = SurfaceDiceMetric(class_thresholds=[1, 1, 1], include_background=True, get_not_nans=True) + res_classes = sur_metric_nans(predictions_empty, predictions_empty) + res, not_nans = sur_metric_nans.aggregate() + np.testing.assert_array_equal(res_classes, [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan]]) + np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float64)) + np.testing.assert_equal(not_nans, torch.tensor([0], dtype=torch.float64)) + + +if __name__ == "__main__": + unittest.main()