From 7be61544b0ef8832004dab181e8a0a9fb41f599d Mon Sep 17 00:00:00 2001 From: Kaustav Mukherjee Date: Sat, 26 Feb 2022 06:13:04 +0530 Subject: [PATCH] 3band indices (#414) * added triband normalized difference index to support indices e.g. Green-Blue NDVI added triband normalized difference index base class to support indices e.g. Green-Blue NDVI * Formatted * formatted the comments * formatted * Update indices.py * Update indices.py * formatted * formatted * formatted * formatted * formatted * frmtted * formatted the init * added test cases * removed probelematic code * formatted * formatted * formatted --- tests/transforms/test_indices.py | 8 ++++ torchgeo/transforms/__init__.py | 2 + torchgeo/transforms/indices.py | 63 ++++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+) diff --git a/tests/transforms/test_indices.py b/tests/transforms/test_indices.py index a794e4fa510..7b6ccc1955a 100644 --- a/tests/transforms/test_indices.py +++ b/tests/transforms/test_indices.py @@ -18,6 +18,7 @@ AppendNDWI, AppendNormalizedDifferenceIndex, AppendSWI, + AppendTriBandNormalizedDifferenceIndex, ) @@ -64,6 +65,13 @@ def test_append_index_batch(batch: Dict[str, Tensor]) -> None: assert output["image"].shape == (b, c + 1, h, w) +def test_append_triband_index_batch(batch: Dict[str, Tensor]) -> None: + b, c, h, w = batch["image"].shape + tr = AppendTriBandNormalizedDifferenceIndex(index_a=0, index_b=0, index_c=0) + output = tr(batch) + assert output["image"].shape == (b, c + 1, h, w) + + @pytest.mark.parametrize( "index", [ diff --git a/torchgeo/transforms/__init__.py b/torchgeo/transforms/__init__.py index 61d4ff9510c..f5aa651e530 100644 --- a/torchgeo/transforms/__init__.py +++ b/torchgeo/transforms/__init__.py @@ -14,6 +14,7 @@ AppendNDWI, AppendNormalizedDifferenceIndex, AppendSWI, + AppendTriBandNormalizedDifferenceIndex, ) from .transforms import AugmentationSequential @@ -29,6 +30,7 @@ "AppendNDWI", "AppendSWI", "AugmentationSequential", + "AppendTriBandNormalizedDifferenceIndex", ) # https://stackoverflow.com/questions/40018681 diff --git a/torchgeo/transforms/indices.py b/torchgeo/transforms/indices.py index 63a45ae517f..aea72a080f2 100644 --- a/torchgeo/transforms/indices.py +++ b/torchgeo/transforms/indices.py @@ -301,3 +301,66 @@ def __init__(self, index_nir: int, index_vre1: int) -> None: index_vre1: index of the Red Edge band, B5 in Sentinel 2 imagery """ super().__init__(index_a=index_nir, index_b=index_vre1) + + +class AppendTriBandNormalizedDifferenceIndex(Module): + r"""Append normalized difference index involving 3 bands as channel to image tensor. + + Computes the following index: + + .. math:: + + \text{NDI} = \frac{A - {B + C}}{A + {B + C}} + + .. versionadded:: 0.3 + """ + + def __init__(self, index_a: int, index_b: int, index_c: int) -> None: + """Initialize a new transform instance. + + Args: + index_a: reference band channel index + index_b: difference band channel index of component 1 + index_c: difference band channel index of component 2 + """ + super().__init__() + self.dim = -3 + self.index_a = index_a + self.index_b = index_b + self.index_c = index_c + + def _compute_index(self, band_a: Tensor, band_b: Tensor, band_c: Tensor) -> Tensor: + """Compute tri-band normalized difference index. + + Args: + band_a: reference band tensor + band_b: difference band tensor component 1 + band_c: difference band tensor component 2 + + Returns: + the index + """ + return (band_a - (band_b + band_c)) / ((band_a + band_b + band_c) + _EPSILON) + + def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: + """Compute and append tri-band normalized difference index to image. + + Args: + sample: a sample or batch dict + + Returns: + the transformed sample + """ + if "image" in sample: + index = self._compute_index( + band_a=sample["image"][..., self.index_a, :, :], + band_b=sample["image"][..., self.index_b, :, :], + band_c=sample["image"][..., self.index_c, :, :], + ) + index = index.unsqueeze(self.dim) + + sample["image"] = torch.cat( # type: ignore[attr-defined] + [sample["image"], index], dim=self.dim + ) + + return sample