diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst
index 98160b41bad..5bee339c40e 100644
--- a/docs/api/datamodules.rst
+++ b/docs/api/datamodules.rst
@@ -57,6 +57,11 @@ BigEarthNet
 
 .. autoclass:: BigEarthNetDataModule
 
+CaBuAr
+^^^^^^
+
+.. autoclass:: CaBuArDataModule
+
 ChaBuD
 ^^^^^^
 
diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 85f2f6e4587..398b715e818 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -217,6 +217,11 @@ BioMassters
 
 .. autoclass:: BioMassters
 
+CaBuAr
+^^^^^^
+
+.. autoclass:: CaBuAr
+
 ChaBuD
 ^^^^^^
 
diff --git a/docs/api/non_geo_datasets.csv b/docs/api/non_geo_datasets.csv
index 746840e1221..508ea2c74b9 100644
--- a/docs/api/non_geo_datasets.csv
+++ b/docs/api/non_geo_datasets.csv
@@ -3,6 +3,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
 `Benin Cashew Plantations`_,S,Airbus Pléiades,"CC-BY-4.0",70,6,"1,122x1,186",10,MSI
 `BigEarthNet`_,C,Sentinel-1/2,"CDLA-Permissive-1.0","590,326",19--43,120x120,10,"SAR, MSI"
 `BioMassters`_,R,Sentinel-1/2 and Lidar,"CC-BY-4.0",,,256x256, 10, "SAR, MSI"
+`CaBuAr`_,CD,Sentinel-2,"OpenRAIL",424,2,512x512,20,MSI
 `ChaBuD`_,CD,Sentinel-2,"OpenRAIL",356,2,512x512,10,MSI
 `Cloud Cover Detection`_,S,Sentinel-2,"CC-BY-4.0","22,728",2,512x512,10,MSI
 `COWC`_,"C, R","CSUAV AFRL, ISPRS, LINZ, AGRC","AGPL-3.0-only","388,435",2,256x256,0.15,RGB
diff --git a/tests/conf/cabuar.yaml b/tests/conf/cabuar.yaml
new file mode 100644
index 00000000000..42705a94542
--- /dev/null
+++ b/tests/conf/cabuar.yaml
@@ -0,0 +1,16 @@
+model:
+  class_path: SemanticSegmentationTask
+  init_args:
+    loss: "ce"
+    model: "unet"
+    backbone: "resnet18"
+    in_channels: 24
+    num_classes: 2
+    num_filters: 1
+    ignore_index: null
+data:
+  class_path: CaBuArDataModule
+  init_args:
+    batch_size: 2
+  dict_kwargs:
+    root: "tests/data/cabuar"
diff --git a/tests/data/cabuar/512x512.hdf5 b/tests/data/cabuar/512x512.hdf5
new file mode 100644
index 00000000000..5d8f16529bb
Binary files /dev/null and b/tests/data/cabuar/512x512.hdf5 differ
diff --git a/tests/data/cabuar/chabud_test.h5 b/tests/data/cabuar/chabud_test.h5
new file mode 100644
index 00000000000..5408b9d27fc
Binary files /dev/null and b/tests/data/cabuar/chabud_test.h5 differ
diff --git a/tests/data/cabuar/data.py b/tests/data/cabuar/data.py
new file mode 100644
index 00000000000..9be447d816d
--- /dev/null
+++ b/tests/data/cabuar/data.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
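+
+# Generates tiny synthetic HDF5 fixtures mirroring the CaBuAr layout: one
+# group per image UUID, each holding 'pre_fire'/'post_fire' cubes and a
+# 'mask', plus a 'fold' attribute that assigns samples to train/val/test
+# splits. The MD5 checksums printed at the end can be recorded for the
+# checksum tests.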
+
+import hashlib
+import os
+import random
+
+import h5py
+import numpy as np
+
+# Sentinel-2 is 12-bit with range 0-4095
+SENTINEL2_MAX = 4096
+
+NUM_CHANNELS = 12
+NUM_CLASSES = 2
+SIZE = 32
+
+np.random.seed(0)
+random.seed(0)
+
+filenames = ['512x512.hdf5', 'chabud_test.h5']
+fold_mapping = {'train': [1, 2, 3, 4], 'val': [0], 'test': ['chabud']}
+
+uris = [
+    'feb08801-64b1-4d11-a3fc-0efaad1f4274_0',
+    'e4d4dbcb-dd92-40cf-a7fe-fda8dd35f367_1',
+    '9fc8c1f4-1858-47c3-953e-1dc8b179a',
+    '3a1358a2-6155-445a-a269-13bebd9741a8_0',
+    '2f8e659c-f457-4527-a57f-bffc3bbe0baa_0',
+    '299ee670-19b1-4a76-bef3-34fd55580711_1',
+    '05cfef86-3e27-42be-a0cb-a61fe2f89e40_0',
+    '0328d12a-4ad8-4504-8ac5-70089db10b4e_1',
+    '04800581-b540-4f9b-9df8-7ee433e83f46_0',
+    '108ae2a9-d7d6-42f7-b89a-90bb75c23ccb_0',
+    '29413474-04b8-4bb1-8b89-fd640023d4a6_0',
+    '43f2e60a-73b4-4f33-b99e-319d892fcab4_0',
+]
+folds = random.choices(fold_mapping['train'], k=4) + [0] * 4 + ['chabud'] * 4
+files = ['512x512.hdf5'] * 8 + ['chabud_test.h5'] * 4
+
+# Remove old data
+for filename in filenames:
+    if os.path.exists(filename):
+        os.remove(filename)
+
+# Create dataset file
+data = np.random.randint(
+    SENTINEL2_MAX, size=(SIZE, SIZE, NUM_CHANNELS), dtype=np.uint16
+)
+gt = np.random.randint(NUM_CLASSES, size=(SIZE, SIZE, 1), dtype=np.uint16)
+
+for filename, uri, fold in zip(files, uris, folds):
+    with h5py.File(filename, 'a') as f:
+        sample = f.create_group(uri)
+        sample.attrs.create(
+            name='fold', data=np.int64(fold) if fold != 'chabud' else fold
+        )
+        sample.create_dataset('pre_fire', data=data)
+        sample.create_dataset('post_fire', data=data)
+        sample.create_dataset('mask', data=gt)
+
+# Compute checksums
+for filename in filenames:
+    with open(filename, 'rb') as f:
+        md5 = hashlib.md5(f.read()).hexdigest()
+        print(f'{filename} md5: {md5}')
diff --git a/tests/datasets/test_cabuar.py b/tests/datasets/test_cabuar.py
new file mode 100644
index 00000000000..967f43ee4d3
--- /dev/null
+++ b/tests/datasets/test_cabuar.py
@@ -0,0 +1,92 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
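+
+# Tests for the CaBuAr dataset. Downloads are monkeypatched to point at the
+# synthetic fixtures in tests/data/cabuar instead of the Hugging Face URLs.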
+
+import os
+from itertools import product
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import pytest
+import torch
+import torch.nn as nn
+from _pytest.fixtures import SubRequest
+from pytest import MonkeyPatch
+
+from torchgeo.datasets import CaBuAr, DatasetNotFoundError
+
+pytest.importorskip('h5py', minversion='3.6')
+
+
+class TestCaBuAr:
+    @pytest.fixture(
+        params=product([CaBuAr.all_bands, CaBuAr.rgb_bands], ['train', 'val', 'test'])
+    )
+    def dataset(
+        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
+    ) -> CaBuAr:
+        data_dir = os.path.join('tests', 'data', 'cabuar')
+        urls = (
+            os.path.join(data_dir, '512x512.hdf5'),
+            os.path.join(data_dir, 'chabud_test.h5'),
+        )
+        monkeypatch.setattr(CaBuAr, 'urls', urls)
+        bands, split = request.param
+        root = tmp_path
+        transforms = nn.Identity()
+        return CaBuAr(
+            root=root,
+            split=split,
+            bands=bands,
+            transforms=transforms,
+            download=True,
+            checksum=True,
+        )
+
+    def test_getitem(self, dataset: CaBuAr) -> None:
+        x = dataset[0]
+        assert isinstance(x, dict)
+        assert isinstance(x['image'], torch.Tensor)
+        assert isinstance(x['mask'], torch.Tensor)
+
+        # Image tests
+        assert x['image'].ndim == 3
+
+        if dataset.bands == CaBuAr.rgb_bands:
+            assert x['image'].shape[0] == 2 * 3
+        elif dataset.bands == CaBuAr.all_bands:
+            assert x['image'].shape[0] == 2 * 12
+
+        # Mask tests:
+        assert x['mask'].ndim == 2
+
+    def test_len(self, dataset: CaBuAr) -> None:
+        assert len(dataset) == 4
+
+    def test_already_downloaded(self, dataset: CaBuAr) -> None:
+        CaBuAr(root=dataset.root, download=True)
+
+    def test_not_downloaded(self, tmp_path: Path) -> None:
+        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
+            CaBuAr(tmp_path)
+
+    def test_invalid_bands(self) -> None:
+        with pytest.raises(AssertionError):
+            CaBuAr(bands=('OK', 'BK'))
+
+    def test_plot(self, dataset: CaBuAr) -> None:
+        dataset.plot(dataset[0], suptitle='Test')
+        plt.close()
+
+        sample = dataset[0]
+        sample['prediction'] = sample['mask'].clone()
+        dataset.plot(sample, suptitle='prediction')
+        plt.close()
+
+    def test_plot_rgb(self, dataset: CaBuAr) -> None:
+        dataset = CaBuAr(root=dataset.root, bands=('B02',))
+        with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"):
+            dataset.plot(dataset[0], suptitle='Single Band')
+
+    def test_invalid_split(self, dataset: CaBuAr) -> None:
+        with pytest.raises(AssertionError):
+            CaBuAr(dataset.root, split='foo')
diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py
index ea4b0646521..00b293096dd 100644
--- a/tests/trainers/test_segmentation.py
+++ b/tests/trainers/test_segmentation.py
@@ -50,6 +50,7 @@ class TestSemanticSegmentationTask:
         'name',
         [
             'agrifieldnet',
+            'cabuar',
             'chabud',
             'chesapeake_cvpr_5',
             'chesapeake_cvpr_7',
@@ -82,7 +83,7 @@ def test_trainer(
         self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
     ) -> None:
         match name:
-            case 'chabud':
+            case 'chabud' | 'cabuar':
                 pytest.importorskip('h5py', minversion='3.6')
             case 'landcoverai':
                 sha256 = (
diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py
index a22a581dea3..498ad9fdb7a 100644
--- a/torchgeo/datamodules/__init__.py
+++ b/torchgeo/datamodules/__init__.py
@@ -5,6 +5,7 @@
 
 from .agrifieldnet import AgriFieldNetDataModule
 from .bigearthnet import BigEarthNetDataModule
+from .cabuar import CaBuArDataModule
 from .chabud import ChaBuDDataModule
 from .chesapeake import ChesapeakeCVPRDataModule
 from .cowc import COWCCountingDataModule
@@ -64,6 +65,7 @@
     'SouthAfricaCropTypeDataModule',
     # NonGeoDataset
     'BigEarthNetDataModule',
+    'CaBuArDataModule',
     'ChaBuDDataModule',
     'COWCCountingDataModule',
     'DeepGlobeLandCoverDataModule',
diff --git a/torchgeo/datamodules/cabuar.py b/torchgeo/datamodules/cabuar.py
new file mode 100644
index 00000000000..2ce459bceae
--- /dev/null
+++ b/torchgeo/datamodules/cabuar.py
@@ -0,0 +1,67 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""CaBuAr datamodule."""
+
+from typing import Any
+
+import torch
+from einops import repeat
+
+from ..datasets import CaBuAr
+from .geo import NonGeoDataModule
+
+
+class CaBuArDataModule(NonGeoDataModule):
+    """LightningDataModule implementation for the CaBuAr dataset.
+
+    Uses the train/val/test splits from the dataset.
+
+    .. versionadded:: 0.6
+    """
+
+    # min/max values computed on train set using 2/98 percentiles
+    min = torch.tensor(
+        [0.0, 1.0, 73.0, 39.0, 46.0, 25.0, 26.0, 21.0, 17.0, 1.0, 20.0, 21.0]
+    )
+    max = torch.tensor(
+        [
+            1926.0,
+            2174.0,
+            2527.0,
+            2950.0,
+            3237.0,
+            3717.0,
+            4087.0,
+            4271.0,
+            4290.0,
+            4219.0,
+            4568.0,
+            3753.0,
+        ]
+    )
+
+    def __init__(
+        self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
+    ) -> None:
+        """Initialize a new CaBuArDataModule instance.
+
+        Args:
+            batch_size: Size of each mini-batch.
+            num_workers: Number of workers for parallel data loading.
+            **kwargs: Additional keyword arguments passed to
+                :class:`~torchgeo.datasets.CaBuAr`.
+        """
+        bands = kwargs.get('bands', CaBuAr.all_bands)
+        band_indices = [CaBuAr.all_bands.index(b) for b in bands]
+        mins = self.min[band_indices]
+        maxs = self.max[band_indices]
+
+        # Change detection, 2 images from different times
+        mins = repeat(mins, 'c -> (t c)', t=2)
+        maxs = repeat(maxs, 'c -> (t c)', t=2)
+
+        # NonGeoDataModule normalizes inputs as (x - mean) / std, so setting
+        # mean=min and std=max-min min-max scales each band to roughly [0, 1]
+        self.mean = mins
+        self.std = maxs - mins
+
+        super().__init__(CaBuAr, batch_size, num_workers, **kwargs)
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index dd52286be9e..170aab9c68d 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -11,6 +11,7 @@
 from .benin_cashews import BeninSmallHolderCashews
 from .bigearthnet import BigEarthNet
 from .biomassters import BioMassters
+from .cabuar import CaBuAr
 from .cbf import CanadianBuildingFootprints
 from .cdl import CDL
 from .chabud import ChaBuD
@@ -197,6 +198,7 @@
     'BeninSmallHolderCashews',
     'BigEarthNet',
     'BioMassters',
+    'CaBuAr',
     'ChaBuD',
     'CloudCoverDetection',
     'COWC',
diff --git a/torchgeo/datasets/cabuar.py b/torchgeo/datasets/cabuar.py
new file mode 100644
index 00000000000..69ca818a70e
--- /dev/null
+++ b/torchgeo/datasets/cabuar.py
@@ -0,0 +1,303 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""CaBuAr dataset."""
+
+import os
+from collections.abc import Callable
+from typing import ClassVar
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from matplotlib.figure import Figure
+from torch import Tensor
+
+from .errors import DatasetNotFoundError
+from .geo import NonGeoDataset
+from .utils import Path, download_url, lazy_import, percentile_normalization
+
+
+class CaBuAr(NonGeoDataset):
+    """CaBuAr dataset.
+
+    `CaBuAr <https://huggingface.co/datasets/DarthReca/california_burned_areas>`__
+    is a dataset for Change detection for Burned area Delineation; part of its
+    splits were used for the ChaBuD ECML-PKDD 2023 Discovery Challenge.
+
+    Dataset features:
+
+    * Sentinel-2 multispectral imagery
+    * binary masks of burned areas
+    * 12 multispectral bands
+    * 424 pairs of pre- and post-fire images at 20 m per pixel resolution (512x512 px)
+
+    Dataset format:
+
+    * HDF5 files containing images and masks
+
+    Dataset classes:
+
+    0. no change
+    1. burned area
+
+    If you use this dataset in your research, please cite the following paper:
+
+    * https://doi.org/10.1109/MGRS.2023.3292467
+
+    .. note::
+
+       This dataset requires the following additional library to be installed:
+
+       * `h5py <https://pypi.org/project/h5py/>`_ to load the dataset
+
+    .. versionadded:: 0.6
+    """
+
+    all_bands = (
+        'B01',
+        'B02',
+        'B03',
+        'B04',
+        'B05',
+        'B06',
+        'B07',
+        'B08',
+        'B8A',
+        'B09',
+        'B11',
+        'B12',
+    )
+    rgb_bands = ('B04', 'B03', 'B02')
+    folds: ClassVar[dict[str, list[object]]] = {
+        'train': [1, 2, 3, 4],
+        'val': [0],
+        'test': ['chabud'],
+    }
+    urls = (
+        'https://huggingface.co/datasets/DarthReca/california_burned_areas/resolve/main/raw/patched/512x512.hdf5',
+        'https://huggingface.co/datasets/DarthReca/california_burned_areas/resolve/main/raw/patched/chabud_test.h5',
+    )
+    filenames = ('512x512.hdf5', 'chabud_test.h5')
+    md5s = ('15d78fb825f9a81dad600db828d22c08', 'a70bb7e4a2788657c2354c4c3d9296fe')
+
+    def __init__(
+        self,
+        root: Path = 'data',
+        split: str = 'train',
+        bands: tuple[str, ...] = all_bands,
+        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
+        download: bool = False,
+        checksum: bool = False,
+    ) -> None:
+        """Initialize a new CaBuAr dataset instance.
+
+        Args:
+            root: root directory where dataset can be found
+            split: one of "train", "val", or "test"
+            bands: the subset of bands to load
+            transforms: a function/transform that takes input sample and its target as
+                entry and returns a transformed version
+            download: if True, download dataset and store it in the root directory
+            checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+        Raises:
+            AssertionError: If ``split`` or ``bands`` arguments are invalid.
+            DatasetNotFoundError: If dataset is not found and *download* is False.
+            DependencyNotFoundError: If h5py is not installed.
+        """
+        lazy_import('h5py')
+
+        assert split in self.folds
+        assert set(bands) <= set(self.all_bands)
+
+        # Set the file index based on the split
+        file_index = 1 if split == 'test' else 0
+
+        self.root = root
+        self.split = split
+        self.bands = bands
+        self.transforms = transforms
+        self.download = download
+        self.checksum = checksum
+        self.filepath = os.path.join(root, self.filenames[file_index])
+        self.band_indices = [self.all_bands.index(b) for b in bands]
+
+        self._verify()
+
+        self.uuids = self._load_uuids()
+
+    def __getitem__(self, index: int) -> dict[str, Tensor]:
+        """Return a sample at the given index within the dataset.
+
+        Args:
+            index: index to return
+
+        Returns:
+            sample containing image and mask
+        """
+        image = self._load_image(index)
+        mask = self._load_target(index)
+
+        sample = {'image': image, 'mask': mask}
+
+        if self.transforms is not None:
+            sample = self.transforms(sample)
+
+        return sample
+
+    def __len__(self) -> int:
+        """Return the number of data points in the dataset.
+
+        Returns:
+            length of the dataset
+        """
+        return len(self.uuids)
+
+    def _load_uuids(self) -> list[str]:
+        """Return the image uuids for the given split.
+
+        Returns:
+            the image uuids
+        """
+        h5py = lazy_import('h5py')
+        uuids = []
+        with h5py.File(self.filepath, 'r') as f:
+            for k, v in f.items():
+                if v.attrs['fold'] in self.folds[self.split] and 'pre_fire' in v.keys():
+                    uuids.append(k)
+        return sorted(uuids)
+
+    def _load_image(self, index: int) -> Tensor:
+        """Load a single image.
+
+        Args:
+            index: index to return
+
+        Returns:
+            the image
+        """
+        h5py = lazy_import('h5py')
+        uuid = self.uuids[index]
+        with h5py.File(self.filepath, 'r') as f:
+            pre_array = f[uuid]['pre_fire'][:]
+            post_array = f[uuid]['post_fire'][:]
+
+        # index specified bands and concatenate
+        pre_array = pre_array[..., self.band_indices]
+        post_array = post_array[..., self.band_indices]
+        array = np.concatenate([pre_array, post_array], axis=-1).astype(np.float32)
+
+        tensor = torch.from_numpy(array)
+        # Convert from HxWxC to CxHxW
+        tensor = tensor.permute((2, 0, 1))
+        return tensor
+
+    def _load_target(self, index: int) -> Tensor:
+        """Load the target mask for a single image.
+
+        Args:
+            index: index to return
+
+        Returns:
+            the target mask
+        """
+        h5py = lazy_import('h5py')
+        uuid = self.uuids[index]
+        with h5py.File(self.filepath, 'r') as f:
+            array = f[uuid]['mask'][:].astype(np.int32).squeeze(axis=-1)
+
+        tensor = torch.from_numpy(array)
+        tensor = tensor.to(torch.long)
+        return tensor
+
+    def _verify(self) -> None:
+        """Verify the integrity of the dataset."""
+        # Check if the files already exist
+        exists = []
+        for filename in self.filenames:
+            filepath = os.path.join(self.root, filename)
+            exists.append(os.path.exists(filepath))
+
+        if all(exists):
+            return
+
+        # Check if the user requested to download the dataset
+        if not self.download:
+            raise DatasetNotFoundError(self)
+
+        # Download the dataset
+        self._download()
+
+    def _download(self) -> None:
+        """Download the dataset."""
+        for url, filename, md5 in zip(self.urls, self.filenames, self.md5s):
+            filepath = os.path.join(self.root, filename)
+            if not os.path.exists(filepath):
+                download_url(
+                    url,
+                    self.root,
+                    filename=filename,
+                    md5=md5 if self.checksum else None,
+                )
+
+    def plot(
+        self,
+        sample: dict[str, Tensor],
+        show_titles: bool = True,
+        suptitle: str | None = None,
+    ) -> Figure:
+        """Plot a sample from the dataset.
+
+        Args:
+            sample: a sample returned by :meth:`__getitem__`
+            show_titles: flag indicating whether to show titles above each panel
+            suptitle: optional suptitle to use for figure
+
+        Returns:
+            a matplotlib Figure with the rendered sample
+        """
+        rgb_indices = []
+        for band in self.rgb_bands:
+            if band in self.bands:
+                rgb_indices.append(self.bands.index(band))
+            else:
+                raise ValueError("Dataset doesn't contain some of the RGB bands")
+
+        mask = sample['mask'].numpy()
+        image_pre = sample['image'][: len(self.bands)][rgb_indices].numpy()
+        image_post = sample['image'][len(self.bands) :][rgb_indices].numpy()
+        image_pre = percentile_normalization(image_pre)
+        image_post = percentile_normalization(image_post)
+
+        ncols = 3
+
+        showing_predictions = 'prediction' in sample
+        if showing_predictions:
+            prediction = sample['prediction']
+            ncols += 1
+
+        fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5))
+
+        axs[0].imshow(np.transpose(image_pre, (1, 2, 0)))
+        axs[0].axis('off')
+        axs[1].imshow(np.transpose(image_post, (1, 2, 0)))
+        axs[1].axis('off')
+        axs[2].imshow(mask)
+        axs[2].axis('off')
+
+        if showing_predictions:
+            axs[3].imshow(prediction)
+            axs[3].axis('off')
+
+        if show_titles:
+            axs[0].set_title('Image Pre')
+            axs[1].set_title('Image Post')
+            axs[2].set_title('Mask')
+            if showing_predictions:
+                axs[3].set_title('Prediction')
+
+        if suptitle is not None:
+            plt.suptitle(suptitle)
+
+        return fig
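
Reviewer note: a minimal usage sketch of the API this diff adds. The root path,
batch size, and worker count below are illustrative, not part of the change.

    from torchgeo.datamodules import CaBuArDataModule
    from torchgeo.datasets import CaBuAr

    # Direct dataset access: each sample is a dict holding an 'image' tensor of
    # shape (2 * num_bands, 512, 512), with the pre-fire and post-fire scenes
    # stacked along the channel axis, and a (512, 512) 'mask' tensor.
    ds = CaBuAr(root='data/cabuar', split='train', download=True, checksum=True)
    ds.plot(ds[0], suptitle='CaBuAr sample')

    # Lightning datamodule: min-max scales each band using the 2/98 train-set
    # percentiles stored in CaBuArDataModule.min/max.
    dm = CaBuArDataModule(
        root='data/cabuar', batch_size=8, num_workers=4, download=True
    )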