Skip to content

Commit

Permalink
Add CaBuAr dataset (#2235)
Browse files Browse the repository at this point in the history
* 🆕 Added CaBuAr dataset

* 🆕 Added CaBuAr datamodule

* 🔨 Added CaBuAr datamodule test

* 🔨 Corrected CaBuAr typing and datamodule test

* 🔨 updated test, corrected docs, minor fixes to dataset and datamodule

* 🔨 CaBuAr test fixes
  • Loading branch information
DarthReca authored Aug 28, 2024
1 parent 042b75e commit ccc314c
Show file tree
Hide file tree
Showing 13 changed files with 564 additions and 1 deletion.
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ BigEarthNet

.. autoclass:: BigEarthNetDataModule

CaBuAr
^^^^^^

.. autoclass:: CaBuArDataModule

ChaBuD
^^^^^^

Expand Down
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,11 @@ BioMassters

.. autoclass:: BioMassters

CaBuAr
^^^^^^

.. autoclass:: CaBuAr

ChaBuD
^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/non_geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`Benin Cashew Plantations`_,S,Airbus Pléiades,"CC-BY-4.0",70,6,"1,122x1,186",10,MSI
`BigEarthNet`_,C,Sentinel-1/2,"CDLA-Permissive-1.0","590,326",19--43,120x120,10,"SAR, MSI"
`BioMassters`_,R,Sentinel-1/2 and Lidar,"CC-BY-4.0",,,256x256, 10, "SAR, MSI"
`CaBuAr`_,CD,Sentinel-2,"OpenRAIL",424,2,512x512,20,MSI
`ChaBuD`_,CD,Sentinel-2,"OpenRAIL",356,2,512x512,10,MSI
`Cloud Cover Detection`_,S,Sentinel-2,"CC-BY-4.0","22,728",2,512x512,10,MSI
`COWC`_,"C, R","CSUAV AFRL, ISPRS, LINZ, AGRC","AGPL-3.0-only","388,435",2,256x256,0.15,RGB
Expand Down
16 changes: 16 additions & 0 deletions tests/conf/cabuar.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: "ce"
model: "unet"
backbone: "resnet18"
in_channels: 24
num_classes: 2
num_filters: 1
ignore_index: null
data:
class_path: CaBuArDataModule
init_args:
batch_size: 2
dict_kwargs:
root: "tests/data/cabuar"
Binary file added tests/data/cabuar/512x512.hdf5
Binary file not shown.
Binary file added tests/data/cabuar/chabud_test.h5
Binary file not shown.
69 changes: 69 additions & 0 deletions tests/data/cabuar/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import random

import h5py
import numpy as np

# Sentinel-2 is 12-bit, so valid pixel values lie in the range 0-4095
SENTINEL2_MAX = 4096

# Dimensions of the synthetic samples
NUM_CHANNELS = 12
NUM_CLASSES = 2
SIZE = 32

# Seed both random number generators so repeated runs draw the same values
np.random.seed(0)
random.seed(0)

filenames = ['512x512.hdf5', 'chabud_test.h5']
fold_mapping = {'train': [1, 2, 3, 4], 'val': [0], 'test': ['chabud']}

# One HDF5 group name per synthetic sample
uris = [
    'feb08801-64b1-4d11-a3fc-0efaad1f4274_0',
    'e4d4dbcb-dd92-40cf-a7fe-fda8dd35f367_1',
    '9fc8c1f4-1858-47c3-953e-1dc8b179a',
    '3a1358a2-6155-445a-a269-13bebd9741a8_0',
    '2f8e659c-f457-4527-a57f-bffc3bbe0baa_0',
    '299ee670-19b1-4a76-bef3-34fd55580711_1',
    '05cfef86-3e27-42be-a0cb-a61fe2f89e40_0',
    '0328d12a-4ad8-4504-8ac5-70089db10b4e_1',
    '04800581-b540-4f9b-9df8-7ee433e83f46_0',
    '108ae2a9-d7d6-42f7-b89a-90bb75c23ccb_0',
    '29413474-04b8-4bb1-8b89-fd640023d4a6_0',
    '43f2e60a-73b4-4f33-b99e-319d892fcab4_0',
]

# Fold of each sample: four randomly drawn training folds, followed by four
# validation entries and four test entries
folds = [
    *random.choices(fold_mapping['train'], k=4),
    *(fold_mapping['val'] * 4),
    *(fold_mapping['test'] * 4),
]

# Destination file of each sample: the first eight go to the first file,
# the last four to the second
files = [filenames[0]] * 8 + [filenames[1]] * 4

# Remove old data: the files are opened in append mode below, so leftovers
# from a previous run must be deleted first
for filename in filenames:
    if os.path.exists(filename):
        os.remove(filename)

# Create dataset file contents. One random image and one random mask are
# drawn once and reused for every sample.
data = np.random.randint(
    SENTINEL2_MAX, size=(SIZE, SIZE, NUM_CHANNELS), dtype=np.uint16
)
gt = np.random.randint(NUM_CLASSES, size=(SIZE, SIZE, 1), dtype=np.uint16)

for filename, uri, fold in zip(files, uris, folds):
    # Append mode lets successive samples accumulate in the same file
    with h5py.File(filename, 'a') as f:
        sample = f.create_group(uri)
        # Numeric folds are stored as int64; the 'chabud' marker stays a string
        sample.attrs.create(
            name='fold', data=np.int64(fold) if fold != 'chabud' else fold
        )
        # The same array is written for both acquisition dates
        sample.create_dataset('pre_fire', data=data)
        sample.create_dataset('post_fire', data=data)
        sample.create_dataset('mask', data=gt)

# Compute checksums of the generated files and print them
for filename in filenames:
    with open(filename, 'rb') as f:
        md5 = hashlib.md5(f.read()).hexdigest()
        print(f'{filename} md5: {md5}')
92 changes: 92 additions & 0 deletions tests/datasets/test_cabuar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

from torchgeo.datasets import CaBuAr, DatasetNotFoundError

pytest.importorskip('h5py', minversion='3.6')


class TestCaBuAr:
    """Tests for the CaBuAr dataset using the files in tests/data/cabuar."""

    @pytest.fixture(
        params=product([CaBuAr.all_bands, CaBuAr.rgb_bands], ['train', 'val', 'test'])
    )
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> CaBuAr:
        """Create a dataset in a temporary root for each (bands, split) pair."""
        band_selection, split_name = request.param

        # Point CaBuAr.urls at the files stored under tests/data/cabuar
        fixture_dir = os.path.join('tests', 'data', 'cabuar')
        local_urls = tuple(
            os.path.join(fixture_dir, name)
            for name in ('512x512.hdf5', 'chabud_test.h5')
        )
        monkeypatch.setattr(CaBuAr, 'urls', local_urls)

        return CaBuAr(
            root=tmp_path,
            split=split_name,
            bands=band_selection,
            transforms=nn.Identity(),
            download=True,
            checksum=True,
        )

    def test_getitem(self, dataset: CaBuAr) -> None:
        """Check the types and dimensions of the first sample."""
        sample = dataset[0]
        assert isinstance(sample, dict)
        for key in ('image', 'mask'):
            assert isinstance(sample[key], torch.Tensor)

        # Image tests
        image = sample['image']
        assert image.ndim == 3

        # Two acquisitions are stacked, hence twice the number of bands
        num_channels = image.shape[0]
        if dataset.bands == CaBuAr.rgb_bands:
            assert num_channels == 2 * 3
        elif dataset.bands == CaBuAr.all_bands:
            assert num_channels == 2 * 12

        # Mask tests
        assert sample['mask'].ndim == 2

    def test_len(self, dataset: CaBuAr) -> None:
        """Every split of the test data holds four samples."""
        assert len(dataset) == 4

    def test_already_downloaded(self, dataset: CaBuAr) -> None:
        """Instantiate again on a root that the fixture already populated."""
        CaBuAr(root=dataset.root, download=True)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """An empty root without download must raise DatasetNotFoundError."""
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            CaBuAr(tmp_path)

    def test_invalid_bands(self) -> None:
        """Unknown band names must be rejected with an AssertionError."""
        with pytest.raises(AssertionError):
            CaBuAr(bands=('OK', 'BK'))

    def test_plot(self, dataset: CaBuAr) -> None:
        """Plot a sample without and with a prediction."""
        dataset.plot(dataset[0], suptitle='Test')
        plt.close()

        with_prediction = dataset[0]
        with_prediction['prediction'] = with_prediction['mask'].clone()
        dataset.plot(with_prediction, suptitle='prediction')
        plt.close()

    def test_plot_rgb(self, dataset: CaBuAr) -> None:
        """Plotting a dataset that lacks the RGB bands must raise ValueError."""
        single_band = CaBuAr(root=dataset.root, bands=('B02',))
        with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"):
            single_band.plot(single_band[0], suptitle='Single Band')

    def test_invalid_split(self, dataset: CaBuAr) -> None:
        """An unknown split name must be rejected with an AssertionError."""
        with pytest.raises(AssertionError):
            CaBuAr(dataset.root, split='foo')
3 changes: 2 additions & 1 deletion tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class TestSemanticSegmentationTask:
'name',
[
'agrifieldnet',
'cabuar',
'chabud',
'chesapeake_cvpr_5',
'chesapeake_cvpr_7',
Expand Down Expand Up @@ -83,7 +84,7 @@ def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
match name:
case 'chabud':
case 'chabud' | 'cabuar':
pytest.importorskip('h5py', minversion='3.6')
case 'landcoverai':
sha256 = (
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .agrifieldnet import AgriFieldNetDataModule
from .bigearthnet import BigEarthNetDataModule
from .cabuar import CaBuArDataModule
from .chabud import ChaBuDDataModule
from .chesapeake import ChesapeakeCVPRDataModule
from .cowc import COWCCountingDataModule
Expand Down Expand Up @@ -65,6 +66,7 @@
'SouthAfricaCropTypeDataModule',
# NonGeoDataset
'BigEarthNetDataModule',
'CaBuArDataModule',
'ChaBuDDataModule',
'COWCCountingDataModule',
'DeepGlobeLandCoverDataModule',
Expand Down
67 changes: 67 additions & 0 deletions torchgeo/datamodules/cabuar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""CaBuAr datamodule."""

from typing import Any

import torch
from einops import repeat

from ..datasets import CaBuAr
from .geo import NonGeoDataModule


class CaBuArDataModule(NonGeoDataModule):
    """LightningDataModule implementation for the CaBuAr dataset.

    Uses the train/val/test splits from the dataset

    .. versionadded:: 0.6
    """

    # Per-band min/max values computed on train set using 2/98 percentiles,
    # listed in the order of CaBuAr.all_bands
    min = torch.tensor(
        [0.0, 1.0, 73.0, 39.0, 46.0, 25.0, 26.0, 21.0, 17.0, 1.0, 20.0, 21.0]
    )
    max = torch.tensor(
        [
            1926.0,
            2174.0,
            2527.0,
            2950.0,
            3237.0,
            3717.0,
            4087.0,
            4271.0,
            4290.0,
            4219.0,
            4568.0,
            3753.0,
        ]
    )

    def __init__(
        self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Initialize a new CaBuArDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.CaBuAr`.
        """
        # Keep only the statistics of the requested bands (all bands by default)
        selected = kwargs.get('bands', CaBuAr.all_bands)
        positions = [CaBuAr.all_bands.index(band) for band in selected]

        # Change detection, 2 images from different times: tile the per-band
        # statistics twice along the channel dimension
        lower = repeat(self.min[positions], 'c -> (t c)', t=2)
        upper = repeat(self.max[positions], 'c -> (t c)', t=2)

        # With these values, an (x - mean) / std normalization amounts to
        # min-max scaling (assumes the base class normalizes this way -- TODO
        # confirm)
        self.mean = lower
        self.std = upper - lower

        super().__init__(CaBuAr, batch_size, num_workers, **kwargs)
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .benin_cashews import BeninSmallHolderCashews
from .bigearthnet import BigEarthNet
from .biomassters import BioMassters
from .cabuar import CaBuAr
from .cbf import CanadianBuildingFootprints
from .cdl import CDL
from .chabud import ChaBuD
Expand Down Expand Up @@ -199,6 +200,7 @@
'BeninSmallHolderCashews',
'BigEarthNet',
'BioMassters',
'CaBuAr',
'ChaBuD',
'CloudCoverDetection',
'COWC',
Expand Down
Loading

0 comments on commit ccc314c

Please sign in to comment.