Skip to content

Commit

Permalink
3band indices (microsoft#414)
Browse files Browse the repository at this point in the history
* added triband normalized difference index to support indices e.g. Green-Blue NDVI

added triband normalized difference index base class to support indices e.g. Green-Blue NDVI

* Formatted

* formatted the comments

* formatted

* Update indices.py

* Update indices.py

* formatted

* formatted

* formatted

* formatted

* formatted

* frmtted

* formatted the init

* added test cases

* removed probelematic code

* formatted

* formatted

* formatted
  • Loading branch information
MATRIX4284 authored Feb 26, 2022
1 parent fd0fd71 commit 341483a
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 0 deletions.
8 changes: 8 additions & 0 deletions tests/transforms/test_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
AppendNDWI,
AppendNormalizedDifferenceIndex,
AppendSWI,
AppendTriBandNormalizedDifferenceIndex,
)


Expand Down Expand Up @@ -64,6 +65,13 @@ def test_append_index_batch(batch: Dict[str, Tensor]) -> None:
assert output["image"].shape == (b, c + 1, h, w)


def test_append_triband_index_batch(batch: Dict[str, Tensor]) -> None:
b, c, h, w = batch["image"].shape
tr = AppendTriBandNormalizedDifferenceIndex(index_a=0, index_b=0, index_c=0)
output = tr(batch)
assert output["image"].shape == (b, c + 1, h, w)


@pytest.mark.parametrize(
"index",
[
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
AppendNDWI,
AppendNormalizedDifferenceIndex,
AppendSWI,
AppendTriBandNormalizedDifferenceIndex,
)
from .transforms import AugmentationSequential

Expand All @@ -29,6 +30,7 @@
"AppendNDWI",
"AppendSWI",
"AugmentationSequential",
"AppendTriBandNormalizedDifferenceIndex",
)

# https://stackoverflow.com/questions/40018681
Expand Down
63 changes: 63 additions & 0 deletions torchgeo/transforms/indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,66 @@ def __init__(self, index_nir: int, index_vre1: int) -> None:
index_vre1: index of the Red Edge band, B5 in Sentinel 2 imagery
"""
super().__init__(index_a=index_nir, index_b=index_vre1)


class AppendTriBandNormalizedDifferenceIndex(Module):
r"""Append normalized difference index involving 3 bands as channel to image tensor.
Computes the following index:
.. math::
\text{NDI} = \frac{A - {B + C}}{A + {B + C}}
.. versionadded:: 0.3
"""

def __init__(self, index_a: int, index_b: int, index_c: int) -> None:
"""Initialize a new transform instance.
Args:
index_a: reference band channel index
index_b: difference band channel index of component 1
index_c: difference band channel index of component 2
"""
super().__init__()
self.dim = -3
self.index_a = index_a
self.index_b = index_b
self.index_c = index_c

def _compute_index(self, band_a: Tensor, band_b: Tensor, band_c: Tensor) -> Tensor:
"""Compute tri-band normalized difference index.
Args:
band_a: reference band tensor
band_b: difference band tensor component 1
band_c: difference band tensor component 2
Returns:
the index
"""
return (band_a - (band_b + band_c)) / ((band_a + band_b + band_c) + _EPSILON)

def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""Compute and append tri-band normalized difference index to image.
Args:
sample: a sample or batch dict
Returns:
the transformed sample
"""
if "image" in sample:
index = self._compute_index(
band_a=sample["image"][..., self.index_a, :, :],
band_b=sample["image"][..., self.index_b, :, :],
band_c=sample["image"][..., self.index_c, :, :],
)
index = index.unsqueeze(self.dim)

sample["image"] = torch.cat( # type: ignore[attr-defined]
[sample["image"], index], dim=self.dim
)

return sample

0 comments on commit 341483a

Please sign in to comment.