diff --git a/CHANGELOG.md b/CHANGELOG.md index 37e3b5d4963..fd629112ab8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `MetricInputTransformer` wrapper ([#2392](https://github.com/Lightning-AI/torchmetrics/pull/2392)) +- Added `input_format` argument to segmentation metrics ([#2572](https://github.com/Lightning-AI/torchmetrics/pull/2572)) + + ### Changed - diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index 6b740bcea53..04f28584b10 100644 --- a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -25,6 +25,7 @@ def _generalized_dice_validate_args( include_background: bool, per_class: bool, weight_type: Literal["square", "simple", "linear"], + input_format: Literal["one-hot", "index"], ) -> None: """Validate the arguments of the metric.""" if num_classes <= 0: @@ -37,6 +38,8 @@ def _generalized_dice_validate_args( raise ValueError( f"Expected argument `weight_type` to be one of 'square', 'simple', 'linear', but got {weight_type}." ) + if input_format not in ["one-hot", "index"]: + raise ValueError(f"Expected argument `input_format` to be one of 'one-hot', 'index', but got {input_format}.") def _generalized_dice_update( @@ -45,15 +48,15 @@ def _generalized_dice_update( num_classes: int, include_background: bool, weight_type: Literal["square", "simple", "linear"] = "square", + input_format: Literal["one-hot", "index"] = "one-hot", ) -> Tensor: """Update the state with the current prediction and target.""" _check_same_shape(preds, target) if preds.ndim < 3: raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.") - if (preds.bool() != preds).any(): # preds is an index tensor + if input_format == "index": preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1) - if (target.bool() != target).any(): # target is an index tensor target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1) if not include_background: @@ -104,6 +107,7 @@ def generalized_dice_score( include_background: bool = True, per_class: bool = False, weight_type: Literal["square", "simple", "linear"] = "square", + input_format: Literal["one-hot", "index"] = "one-hot", ) -> Tensor: """Compute the Generalized Dice Score for semantic segmentation. @@ -114,6 +118,8 @@ def generalized_dice_score( include_background: Whether to include the background class in the computation per_class: Whether to compute the IoU for each class separately, else average over all classes weight_type: Type of weight factor to apply to the classes. One of ``"square"``, ``"simple"``, or ``"linear"`` + input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors + or ``"index"`` for index tensors Returns: The Generalized Dice Score @@ -133,6 +139,8 @@ def generalized_dice_score( [0.4715, 0.4925, 0.4797, 0.5267, 0.4788]]) """ - _generalized_dice_validate_args(num_classes, include_background, per_class, weight_type) - numerator, denominator = _generalized_dice_update(preds, target, num_classes, include_background, weight_type) + _generalized_dice_validate_args(num_classes, include_background, per_class, weight_type, input_format) + numerator, denominator = _generalized_dice_update( + preds, target, num_classes, include_background, weight_type, input_format + ) return _generalized_dice_compute(numerator, denominator, per_class) diff --git a/src/torchmetrics/functional/segmentation/mean_iou.py b/src/torchmetrics/functional/segmentation/mean_iou.py index 0a4e24da6e1..278257d04b1 100644 --- a/src/torchmetrics/functional/segmentation/mean_iou.py +++ b/src/torchmetrics/functional/segmentation/mean_iou.py @@ -15,6 +15,7 @@ import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.functional.segmentation.utils import _ignore_background from torchmetrics.utilities.checks import _check_same_shape @@ -25,6 +26,7 @@ def _mean_iou_validate_args( num_classes: int, include_background: bool, per_class: bool, + input_format: Literal["one-hot", "index"] = "one-hot", ) -> None: """Validate the arguments of the metric.""" if num_classes <= 0: @@ -33,6 +35,8 @@ def _mean_iou_validate_args( raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.") if not isinstance(per_class, bool): raise ValueError(f"Expected argument `per_class` must be a boolean, but got {per_class}.") + if input_format not in ["one-hot", "index"]: + raise ValueError(f"Expected argument `input_format` to be one of 'one-hot', 'index', but got {input_format}.") def _mean_iou_update( @@ -40,13 +44,13 @@ def _mean_iou_update( target: Tensor, num_classes: int, include_background: bool = False, + input_format: Literal["one-hot", "index"] = "one-hot", ) -> Tuple[Tensor, Tensor]: """Update the intersection and union counts for the mean IoU computation.""" _check_same_shape(preds, target) - if (preds.bool() != preds).any(): # preds is an index tensor + if input_format == "index": preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1) - if (target.bool() != target).any(): # target is an index tensor target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1) if not include_background: @@ -76,6 +80,7 @@ def mean_iou( num_classes: int, include_background: bool = True, per_class: bool = False, + input_format: Literal["one-hot", "index"] = "one-hot", ) -> Tensor: """Calculates the mean Intersection over Union (mIoU) for semantic segmentation. @@ -85,6 +90,8 @@ def mean_iou( num_classes: Number of classes include_background: Whether to include the background class in the computation per_class: Whether to compute the IoU for each class separately, else average over all classes + input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors + or ``"index"`` for index tensors Returns: The mean IoU score @@ -104,6 +111,6 @@ def mean_iou( [0.3085, 0.3267, 0.3155, 0.3575, 0.3147]]) """ - _mean_iou_validate_args(num_classes, include_background, per_class) - intersection, union = _mean_iou_update(preds, target, num_classes, include_background) + _mean_iou_validate_args(num_classes, include_background, per_class, input_format) + intersection, union = _mean_iou_update(preds, target, num_classes, include_background, input_format) return _mean_iou_compute(intersection, union, per_class=per_class) diff --git a/src/torchmetrics/segmentation/generalized_dice.py b/src/torchmetrics/segmentation/generalized_dice.py index 646ba63fbcf..66f09437000 100644 --- a/src/torchmetrics/segmentation/generalized_dice.py +++ b/src/torchmetrics/segmentation/generalized_dice.py @@ -53,12 +53,12 @@ class GeneralizedDiceScore(Metric): - ``preds`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)`` - can be provided, where the integer values correspond to the class index. That format will be automatically - converted to a one-hot tensor. + can be provided, where the integer values correspond to the class index. The input type can be controlled + with the ``input_format`` argument. - ``target`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)`` - can be provided, where the integer values correspond to the class index. That format will be automatically - converted to a one-hot tensor. + can be provided, where the integer values correspond to the class index. The input type can be controlled + with the ``input_format`` argument. As output to ``forward`` and ``compute`` the metric returns the following output: @@ -72,6 +72,8 @@ class GeneralizedDiceScore(Metric): per_class: Whether to compute the metric for each class separately. weight_type: The type of weight to apply to each class. Can be one of ``"square"``, ``"simple"``, or ``"linear"``. + input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors + or ``"index"`` for index tensors kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Raises: @@ -83,6 +85,8 @@ class GeneralizedDiceScore(Metric): If ``per_class`` is not a boolean ValueError: If ``weight_type`` is not one of ``"square"``, ``"simple"``, or ``"linear"`` + ValueError: + If ``input_format`` is not one of ``"one-hot"`` or ``"index"`` Example: >>> import torch @@ -116,14 +120,16 @@ def __init__( include_background: bool = True, per_class: bool = False, weight_type: Literal["square", "simple", "linear"] = "square", + input_format: Literal["one-hot", "index"] = "one-hot", **kwargs: Any, ) -> None: super().__init__(**kwargs) - _generalized_dice_validate_args(num_classes, include_background, per_class, weight_type) + _generalized_dice_validate_args(num_classes, include_background, per_class, weight_type, input_format) self.num_classes = num_classes self.include_background = include_background self.per_class = per_class self.weight_type = weight_type + self.input_format = input_format num_classes = num_classes - 1 if not include_background else num_classes self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum") @@ -132,7 +138,7 @@ def __init__( def update(self, preds: Tensor, target: Tensor) -> None: """Update the state with new data.""" numerator, denominator = _generalized_dice_update( - preds, target, self.num_classes, self.include_background, self.weight_type + preds, target, self.num_classes, self.include_background, self.weight_type, self.input_format ) self.score += _generalized_dice_compute(numerator, denominator, self.per_class).sum(dim=0) self.samples += preds.shape[0] diff --git a/src/torchmetrics/segmentation/mean_iou.py b/src/torchmetrics/segmentation/mean_iou.py index f36f7fc3bc0..77d465ebd21 100644 --- a/src/torchmetrics/segmentation/mean_iou.py +++ b/src/torchmetrics/segmentation/mean_iou.py @@ -15,6 +15,7 @@ import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.functional.segmentation.mean_iou import _mean_iou_compute, _mean_iou_update, _mean_iou_validate_args from torchmetrics.metric import Metric @@ -36,12 +37,12 @@ class MeanIoU(Metric): - ``preds`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)`` - can be provided, where the integer values correspond to the class index. That format will be automatically - converted to a one-hot tensor. + can be provided, where the integer values correspond to the class index. The input type can be controlled + with the ``input_format`` argument. - ``target`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)`` - can be provided, where the integer values correspond to the class index. That format will be automatically - converted to a one-hot tensor. + can be provided, where the integer values correspond to the class index. The input type can be controlled + with the ``input_format`` argument. As output to ``forward`` and ``compute`` the metric returns the following output: @@ -54,6 +55,8 @@ class MeanIoU(Metric): include_background: Whether to include the background class in the computation per_class: Whether to compute the IoU for each class separately. If set to ``False``, the metric will compute the mean IoU over all classes. + input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors + or ``"index"`` for index tensors kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Raises: @@ -63,6 +66,8 @@ class MeanIoU(Metric): If ``include_background`` is not a boolean ValueError: If ``per_class`` is not a boolean + ValueError: + If ``input_format`` is not one of ``"one-hot"`` or ``"index"`` Example: >>> import torch @@ -95,20 +100,24 @@ def __init__( num_classes: int, include_background: bool = True, per_class: bool = False, + input_format: Literal["one-hot", "index"] = "one-hot", **kwargs: Any, ) -> None: super().__init__(**kwargs) - _mean_iou_validate_args(num_classes, include_background, per_class) + _mean_iou_validate_args(num_classes, include_background, per_class, input_format) self.num_classes = num_classes self.include_background = include_background self.per_class = per_class + self.input_format = input_format num_classes = num_classes - 1 if not include_background else num_classes self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="mean") def update(self, preds: Tensor, target: Tensor) -> None: """Update the state with the new data.""" - intersection, union = _mean_iou_update(preds, target, self.num_classes, self.include_background) + intersection, union = _mean_iou_update( + preds, target, self.num_classes, self.include_background, self.input_format + ) score = _mean_iou_compute(intersection, union, per_class=self.per_class) self.score += score.mean(0) if self.per_class else score.mean() diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py index ed80e6fd6d7..a2bbab7b921 100644 --- a/tests/unittests/segmentation/test_generalized_dice_score.py +++ b/tests/unittests/segmentation/test_generalized_dice_score.py @@ -40,13 +40,13 @@ def _reference_generalized_dice( preds: torch.Tensor, target: torch.Tensor, + input_format: str, include_background: bool = True, reduce: bool = True, ): """Calculate reference metric for `MeanIoU`.""" - if (preds.bool() != preds).any(): # preds is an index tensor + if input_format == "index": preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1) - if (target.bool() != target).any(): # target is an index tensor target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1) val = compute_generalized_dice(preds, target, include_background=include_background) if reduce: @@ -55,11 +55,11 @@ def _reference_generalized_dice( @pytest.mark.parametrize( - "preds, target", + "preds, target, input_format", [ - (_inputs1.preds, _inputs1.target), - (_inputs2.preds, _inputs2.target), - (_inputs3.preds, _inputs3.target), + (_inputs1.preds, _inputs1.target, "one-hot"), + (_inputs2.preds, _inputs2.target, "one-hot"), + (_inputs3.preds, _inputs3.target, "index"), ], ) @pytest.mark.parametrize("include_background", [True, False]) @@ -67,23 +67,42 @@ class TestMeanIoU(MetricTester): """Test class for `MeanIoU` metric.""" @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_mean_iou_class(self, preds, target, include_background, ddp): + def test_mean_iou_class(self, preds, target, input_format, include_background, ddp): """Test class implementation of metric.""" self.run_class_metric_test( ddp=ddp, preds=preds, target=target, metric_class=GeneralizedDiceScore, - reference_metric=partial(_reference_generalized_dice, include_background=include_background, reduce=True), - metric_args={"num_classes": NUM_CLASSES, "include_background": include_background}, + reference_metric=partial( + _reference_generalized_dice, + input_format=input_format, + include_background=include_background, + reduce=True, + ), + metric_args={ + "num_classes": NUM_CLASSES, + "include_background": include_background, + "input_format": input_format, + }, ) - def test_mean_iou_functional(self, preds, target, include_background): + def test_mean_iou_functional(self, preds, target, input_format, include_background): """Test functional implementation of metric.""" self.run_functional_metric_test( preds=preds, target=target, metric_functional=generalized_dice_score, - reference_metric=partial(_reference_generalized_dice, include_background=include_background, reduce=False), - metric_args={"num_classes": NUM_CLASSES, "include_background": include_background, "per_class": False}, + reference_metric=partial( + _reference_generalized_dice, + input_format=input_format, + include_background=include_background, + reduce=False, + ), + metric_args={ + "num_classes": NUM_CLASSES, + "include_background": include_background, + "per_class": False, + "input_format": input_format, + }, ) diff --git a/tests/unittests/segmentation/test_mean_iou.py b/tests/unittests/segmentation/test_mean_iou.py index 013b71572d6..68c2b060a9e 100644 --- a/tests/unittests/segmentation/test_mean_iou.py +++ b/tests/unittests/segmentation/test_mean_iou.py @@ -40,14 +40,14 @@ def _reference_mean_iou( preds: torch.Tensor, target: torch.Tensor, + input_format: str, include_background: bool = True, per_class: bool = True, reduce: bool = True, ): """Calculate reference metric for `MeanIoU`.""" - if (preds.bool() != preds).any(): # preds is an index tensor + if input_format == "index": preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1) - if (target.bool() != target).any(): # target is an index tensor target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1) val = compute_iou(preds, target, include_background=include_background) @@ -58,11 +58,11 @@ def _reference_mean_iou( @pytest.mark.parametrize( - "preds, target", + "preds, target, input_format", [ - (_inputs1.preds, _inputs1.target), - (_inputs2.preds, _inputs2.target), - (_inputs3.preds, _inputs3.target), + (_inputs1.preds, _inputs1.target, "one-hot"), + (_inputs2.preds, _inputs2.target, "one-hot"), + (_inputs3.preds, _inputs3.target, "index"), ], ) @pytest.mark.parametrize("include_background", [True, False]) @@ -73,7 +73,7 @@ class TestMeanIoU(MetricTester): @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) @pytest.mark.parametrize("per_class", [True, False]) - def test_mean_iou_class(self, preds, target, include_background, per_class, ddp): + def test_mean_iou_class(self, preds, target, input_format, include_background, per_class, ddp): """Test class implementation of metric.""" self.run_class_metric_test( ddp=ddp, @@ -81,17 +81,33 @@ def test_mean_iou_class(self, preds, target, include_background, per_class, ddp) target=target, metric_class=MeanIoU, reference_metric=partial( - _reference_mean_iou, include_background=include_background, per_class=per_class, reduce=True + _reference_mean_iou, + input_format=input_format, + include_background=include_background, + per_class=per_class, + reduce=True, ), - metric_args={"num_classes": NUM_CLASSES, "include_background": include_background, "per_class": per_class}, + metric_args={ + "num_classes": NUM_CLASSES, + "include_background": include_background, + "per_class": per_class, + "input_format": input_format, + }, ) - def test_mean_iou_functional(self, preds, target, include_background): + def test_mean_iou_functional(self, preds, target, input_format, include_background): """Test functional implementation of metric.""" self.run_functional_metric_test( preds=preds, target=target, metric_functional=mean_iou, - reference_metric=partial(_reference_mean_iou, include_background=include_background, reduce=False), - metric_args={"num_classes": NUM_CLASSES, "include_background": include_background, "per_class": True}, + reference_metric=partial( + _reference_mean_iou, input_format=input_format, include_background=include_background, reduce=False + ), + metric_args={ + "num_classes": NUM_CLASSES, + "include_background": include_background, + "per_class": True, + "input_format": input_format, + }, )