Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CaBuAr dataset #2235

Merged
merged 7 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ BigEarthNet

.. autoclass:: BigEarthNetDataModule

CaBuAr
^^^^^^

.. autoclass:: CaBuArDataModule

ChaBuD
^^^^^^

Expand Down
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,11 @@ BioMassters

.. autoclass:: BioMassters

CaBuAr
^^^^^^

.. autoclass:: CaBuAr

ChaBuD
^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/non_geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`Benin Cashew Plantations`_,S,Airbus Pléiades,"CC-BY-4.0",70,6,"1,122x1,186",10,MSI
`BigEarthNet`_,C,Sentinel-1/2,"CDLA-Permissive-1.0","590,326",19--43,120x120,10,"SAR, MSI"
`BioMassters`_,R,Sentinel-1/2 and Lidar,"CC-BY-4.0",,,256x256, 10, "SAR, MSI"
`CaBuAr`_,CD,Sentinel-2,"OpenRAIL",424,2,512x512,20,MSI
`ChaBuD`_,CD,Sentinel-2,"OpenRAIL",356,2,512x512,10,MSI
`Cloud Cover Detection`_,S,Sentinel-2,"CC-BY-4.0","22,728",2,512x512,10,MSI
`COWC`_,"C, R","CSUAV AFRL, ISPRS, LINZ, AGRC","AGPL-3.0-only","388,435",2,256x256,0.15,RGB
Expand Down
16 changes: 16 additions & 0 deletions tests/conf/cabuar.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: "ce"
model: "unet"
backbone: "resnet18"
in_channels: 24
num_classes: 2
num_filters: 1
ignore_index: null
data:
class_path: CaBuArDataModule
init_args:
batch_size: 2
dict_kwargs:
root: "tests/data/cabuar"
Binary file added tests/data/cabuar/512x512.hdf5
Binary file not shown.
Binary file added tests/data/cabuar/chabud_test.h5
Binary file not shown.
69 changes: 69 additions & 0 deletions tests/data/cabuar/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import random

import h5py
import numpy as np

# Sentinel-2 is 12-bit with range 0-4095
SENTINEL2_MAX = 4096

NUM_CHANNELS = 12
NUM_CLASSES = 2
SIZE = 32

np.random.seed(0)
random.seed(0)

filenames = ['512x512.hdf5', 'chabud_test.h5']
fold_mapping = {'train': [1, 2, 3, 4], 'val': [0], 'test': ['chabud']}

uris = [
'feb08801-64b1-4d11-a3fc-0efaad1f4274_0',
'e4d4dbcb-dd92-40cf-a7fe-fda8dd35f367_1',
'9fc8c1f4-1858-47c3-953e-1dc8b179a',
'3a1358a2-6155-445a-a269-13bebd9741a8_0',
'2f8e659c-f457-4527-a57f-bffc3bbe0baa_0',
'299ee670-19b1-4a76-bef3-34fd55580711_1',
'05cfef86-3e27-42be-a0cb-a61fe2f89e40_0',
'0328d12a-4ad8-4504-8ac5-70089db10b4e_1',
'04800581-b540-4f9b-9df8-7ee433e83f46_0',
'108ae2a9-d7d6-42f7-b89a-90bb75c23ccb_0',
'29413474-04b8-4bb1-8b89-fd640023d4a6_0',
'43f2e60a-73b4-4f33-b99e-319d892fcab4_0',
]
folds = random.choices(fold_mapping['train'], k=4) + [0] * 4 + ['chabud'] * 4
files = ['512x512.hdf5'] * 8 + ['chabud_test.h5'] * 4

# Remove old data
for filename in filenames:
if os.path.exists(filename):
os.remove(filename)

# Create dataset file
data = np.random.randint(
SENTINEL2_MAX, size=(SIZE, SIZE, NUM_CHANNELS), dtype=np.uint16
)
gt = np.random.randint(NUM_CLASSES, size=(SIZE, SIZE, 1), dtype=np.uint16)

for filename, uri, fold in zip(files, uris, folds):
with h5py.File(filename, 'a') as f:
sample = f.create_group(uri)
sample.attrs.create(
name='fold', data=np.int64(fold) if fold != 'chabud' else fold
)
sample.create_dataset
sample.create_dataset('pre_fire', data=data)
sample.create_dataset('post_fire', data=data)
sample.create_dataset('mask', data=gt)

# Compute checksums
for filename in filenames:
with open(filename, 'rb') as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f'{filename} md5: {md5}')
92 changes: 92 additions & 0 deletions tests/datasets/test_cabuar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

from torchgeo.datasets import CaBuAr, DatasetNotFoundError

pytest.importorskip('h5py', minversion='3.6')


class TestCaBuAr:
@pytest.fixture(
params=product([CaBuAr.all_bands, CaBuAr.rgb_bands], ['train', 'val', 'test'])
)
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> CaBuAr:
data_dir = os.path.join('tests', 'data', 'cabuar')
urls = (
os.path.join(data_dir, '512x512.hdf5'),
os.path.join(data_dir, 'chabud_test.h5'),
)
monkeypatch.setattr(CaBuAr, 'urls', urls)
bands, split = request.param
root = tmp_path
transforms = nn.Identity()
return CaBuAr(
root=root,
split=split,
bands=bands,
transforms=transforms,
download=True,
checksum=True,
)

def test_getitem(self, dataset: CaBuAr) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x['image'], torch.Tensor)
assert isinstance(x['mask'], torch.Tensor)

# Image tests
assert x['image'].ndim == 3

if dataset.bands == CaBuAr.rgb_bands:
assert x['image'].shape[0] == 2 * 3
elif dataset.bands == CaBuAr.all_bands:
assert x['image'].shape[0] == 2 * 12

# Mask tests:
assert x['mask'].ndim == 2

def test_len(self, dataset: CaBuAr) -> None:
assert len(dataset) == 4

def test_already_downloaded(self, dataset: CaBuAr) -> None:
CaBuAr(root=dataset.root, download=True)

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
CaBuAr(tmp_path)

def test_invalid_bands(self) -> None:
with pytest.raises(AssertionError):
CaBuAr(bands=('OK', 'BK'))

def test_plot(self, dataset: CaBuAr) -> None:
dataset.plot(dataset[0], suptitle='Test')
plt.close()

sample = dataset[0]
sample['prediction'] = sample['mask'].clone()
dataset.plot(sample, suptitle='prediction')
plt.close()

def test_plot_rgb(self, dataset: CaBuAr) -> None:
dataset = CaBuAr(root=dataset.root, bands=('B02',))
with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"):
dataset.plot(dataset[0], suptitle='Single Band')

def test_invalid_split(self, dataset: CaBuAr) -> None:
with pytest.raises(AssertionError):
CaBuAr(dataset.root, split='foo')
3 changes: 2 additions & 1 deletion tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class TestSemanticSegmentationTask:
'name',
[
'agrifieldnet',
'cabuar',
'chabud',
'chesapeake_cvpr_5',
'chesapeake_cvpr_7',
Expand Down Expand Up @@ -82,7 +83,7 @@ def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
match name:
case 'chabud':
case 'chabud' | 'cabuar':
pytest.importorskip('h5py', minversion='3.6')
case 'landcoverai':
sha256 = (
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .agrifieldnet import AgriFieldNetDataModule
from .bigearthnet import BigEarthNetDataModule
from .cabuar import CaBuArDataModule
from .chabud import ChaBuDDataModule
from .chesapeake import ChesapeakeCVPRDataModule
from .cowc import COWCCountingDataModule
Expand Down Expand Up @@ -64,6 +65,7 @@
'SouthAfricaCropTypeDataModule',
# NonGeoDataset
'BigEarthNetDataModule',
'CaBuArDataModule',
'ChaBuDDataModule',
'COWCCountingDataModule',
'DeepGlobeLandCoverDataModule',
Expand Down
67 changes: 67 additions & 0 deletions torchgeo/datamodules/cabuar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""CaBuAr datamodule."""

from typing import Any

import torch
from einops import repeat

from ..datasets import CaBuAr
from .geo import NonGeoDataModule


class CaBuArDataModule(NonGeoDataModule):
"""LightningDataModule implementation for the CaBuAr dataset.

Uses the train/val/test splits from the dataset

.. versionadded:: 0.6
"""

# min/max values computed on train set using 2/98 percentiles
min = torch.tensor(
[0.0, 1.0, 73.0, 39.0, 46.0, 25.0, 26.0, 21.0, 17.0, 1.0, 20.0, 21.0]
)
max = torch.tensor(
[
1926.0,
2174.0,
2527.0,
2950.0,
3237.0,
3717.0,
4087.0,
4271.0,
4290.0,
4219.0,
4568.0,
3753.0,
]
)

def __init__(
self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
"""Initialize a new CaBuArDataModule instance.

Args:
batch_size: Size of each mini-batch.
num_workers: Number of workers for parallel data loading.
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.CaBuAr`.
"""
bands = kwargs.get('bands', CaBuAr.all_bands)
band_indices = [CaBuAr.all_bands.index(b) for b in bands]
mins = self.min[band_indices]
maxs = self.max[band_indices]

# Change detection, 2 images from different times
mins = repeat(mins, 'c -> (t c)', t=2)
maxs = repeat(maxs, 'c -> (t c)', t=2)

self.mean = mins
self.std = maxs - mins

super().__init__(CaBuAr, batch_size, num_workers, **kwargs)
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .benin_cashews import BeninSmallHolderCashews
from .bigearthnet import BigEarthNet
from .biomassters import BioMassters
from .cabuar import CaBuAr
from .cbf import CanadianBuildingFootprints
from .cdl import CDL
from .chabud import ChaBuD
Expand Down Expand Up @@ -197,6 +198,7 @@
'BeninSmallHolderCashews',
'BigEarthNet',
'BioMassters',
'CaBuAr',
'ChaBuD',
'CloudCoverDetection',
'COWC',
Expand Down
Loading
Loading