Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CaBuAr dataset #2235

Merged
merged 7 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
DarthReca marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,11 @@ BioMassters

.. autoclass:: BioMassters

CaBuAr
^^^^^^

.. autoclass:: CaBuAr

ChaBuD
^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/non_geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`BigEarthNet`_,C,Sentinel-1/2,"CDLA-Permissive-1.0","590,326",19--43,120x120,10,"SAR, MSI"
`BioMassters`_,R,Sentinel-1/2 and Lidar,"CC-BY-4.0",,,256x256, 10, "SAR, MSI"
`CaBuAr`_,CD,Sentinel-2,"OpenRAIL",424,2,512x512,20,MSI
`ChaBuD`_,CD,Sentinel-2,"OpenRAIL",356,2,512x512,10,MSI
DarthReca marked this conversation as resolved.
Show resolved Hide resolved
`Cloud Cover Detection`_,S,Sentinel-2,"CC-BY-4.0","22,728",2,512x512,10,MSI
`COWC`_,"C, R","CSUAV AFRL, ISPRS, LINZ, AGRC","AGPL-3.0-only","388,435",2,256x256,0.15,RGB
`CropHarvest`_,"C","Sentinel-1/2, SRTM, ERA5","CC-BY-SA-4.0","70,213",351,1x1,10,"SAR, MSI, SRTM"
Expand Down
Binary file added tests/data/cabuar/512x512.hdf5
Binary file not shown.
Binary file added tests/data/cabuar/chabud_test.h5
Binary file not shown.
80 changes: 80 additions & 0 deletions tests/data/cabuar/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import random

import h5py
import numpy as np

# Sentinel-2 is 12-bit with range 0-4095.
# np.random.randint excludes its upper bound, hence 4096 here.
SENTINEL2_MAX = 4096

NUM_CHANNELS = 12
NUM_CLASSES = 2
SIZE = 32

# Seed both generators so that re-running the script reproduces the same
# fold assignment and pixel values.
np.random.seed(0)
random.seed(0)

filenames = ['512x512.hdf5', 'chabud_test.h5']
fold_mapping = {'train': [1, 2, 3, 4], 'val': [0], 'test': ['chabud']}

# One HDF5 group is created per URI.
uris = [
    'feb08801-64b1-4d11-a3fc-0efaad1f4274_0',
    'e4d4dbcb-dd92-40cf-a7fe-fda8dd35f367_1',
    '9fc8c1f4-1858-47c3-953e-1dc8b179a',
    '3a1358a2-6155-445a-a269-13bebd9741a8_0',
    '2f8e659c-f457-4527-a57f-bffc3bbe0baa_0',
    '299ee670-19b1-4a76-bef3-34fd55580711_1',
    '05cfef86-3e27-42be-a0cb-a61fe2f89e40_0',
    '0328d12a-4ad8-4504-8ac5-70089db10b4e_1',
    '04800581-b540-4f9b-9df8-7ee433e83f46_0',
    '108ae2a9-d7d6-42f7-b89a-90bb75c23ccb_0',
    '29413474-04b8-4bb1-8b89-fd640023d4a6_0',
    '43f2e60a-73b4-4f33-b99e-319d892fcab4_0',
]

# Four samples per split: random training folds, then the validation fold,
# then the 'chabud' marker used for the test split.
folds = (
    [random.sample(fold_mapping['train'], 1)[0] for _ in range(4)]
    + [0] * 4
    + ['chabud'] * 4
)
# The first eight samples go to the main file, the last four to the test file.
files = ['512x512.hdf5'] * 8 + ['chabud_test.h5'] * 4


# Remove old data; files are opened in append mode below, so stale groups
# would otherwise collide with the ones being created.
for filename in filenames:
    if os.path.exists(filename):
        os.remove(filename)

# Create dataset contents (the same arrays are reused for every sample)
data = np.random.randint(
    SENTINEL2_MAX, size=(SIZE, SIZE, NUM_CHANNELS), dtype=np.uint16
)
gt = np.random.randint(NUM_CLASSES, size=(SIZE, SIZE, 1), dtype=np.uint16)

for filename, uri, fold in zip(files, uris, folds):
    with h5py.File(filename, 'a') as f:
        sample = f.create_group(uri)
        # Numeric folds are stored as int64; the 'chabud' marker stays a string.
        fold_attr = fold if fold == 'chabud' else np.int64(fold)
        sample.attrs.create(name='fold', data=fold_attr)
        sample.create_dataset('pre_fire', data=data)
        sample.create_dataset('post_fire', data=data)
        sample.create_dataset('mask', data=gt)

# Compute checksums
for filename in filenames:
    with open(filename, 'rb') as f:
        md5 = hashlib.md5(f.read()).hexdigest()
        print(f'{filename} md5: {md5}')
40 changes: 40 additions & 0 deletions tests/datamodules/test_cabuar.py
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import matplotlib.pyplot as plt
import pytest

from torchgeo.datamodules import CaBuArDataModule
from torchgeo.datasets import unbind_samples


class TestCaBuArDataModule:
    """Smoke tests for CaBuArDataModule against the data in tests/data/cabuar."""

    @pytest.fixture
    def datamodule(self) -> CaBuArDataModule:
        """Return a prepared datamodule with batch size 1 and no worker processes."""
        module = CaBuArDataModule(
            root=os.path.join('tests', 'data', 'cabuar'), batch_size=1, num_workers=0
        )
        module.prepare_data()
        return module

    def test_train_dataloader(self, datamodule: CaBuArDataModule) -> None:
        """A batch can be drawn from the training loader after 'fit' setup."""
        datamodule.setup('fit')
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: CaBuArDataModule) -> None:
        """A batch can be drawn from the validation loader after 'validate' setup."""
        datamodule.setup('validate')
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: CaBuArDataModule) -> None:
        """A batch can be drawn from the test loader after 'test' setup."""
        datamodule.setup('test')
        loader = datamodule.test_dataloader()
        next(iter(loader))

    def test_plot(self, datamodule: CaBuArDataModule) -> None:
        """The first sample of a validation batch can be plotted."""
        datamodule.setup('validate')
        loader = datamodule.val_dataloader()
        first_sample = unbind_samples(next(iter(loader)))[0]
        datamodule.plot(first_sample)
        plt.close()
102 changes: 102 additions & 0 deletions tests/datasets/test_cabuar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

import torchgeo.datasets.utils
from torchgeo.datasets import CaBuAr, DatasetNotFoundError

# Skip every test in this module unless h5py >= 3.6 can be imported.
pytest.importorskip('h5py', minversion='3.6')


def download_url(
DarthReca marked this conversation as resolved.
Show resolved Hide resolved
url: str, root: str | Path, filename: str, *args: str, **kwargs: str
) -> None:
shutil.copy(url, os.path.join(root, filename))


class TestCaBuAr:
    """Unit tests for the CaBuAr dataset using the files in tests/data/cabuar."""

    # Every band set is combined with every split. zip() would stop at the
    # shorter list and silently leave the 'test' split untested.
    @pytest.fixture(
        params=[
            (bands, split)
            for bands in (CaBuAr.all_bands, CaBuAr.rgb_bands)
            for split in ('train', 'val', 'test')
        ]
    )
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> CaBuAr:
        """Build a CaBuAr dataset whose download is redirected to local files.

        The module-level ``download_url`` stub, the local file paths and their
        checksums are patched in, so ``download=True`` copies the fixtures into
        ``tmp_path`` and ``checksum=True`` is checked against them.
        """
        monkeypatch.setattr(torchgeo.datasets.cabuar, 'download_url', download_url)
        data_dir = os.path.join('tests', 'data', 'cabuar')
        urls = (
            os.path.join(data_dir, '512x512.hdf5'),
            os.path.join(data_dir, 'chabud_test.h5'),
        )
        md5s = ('fd7d2f800562a5bb2c9f101ebb9104b2', '41ba3903e7d9db2d549c72261d6a6d53')
        monkeypatch.setattr(CaBuAr, 'urls', urls)
        monkeypatch.setattr(CaBuAr, 'md5s', md5s)
        bands, split = request.param
        root = tmp_path
        transforms = nn.Identity()
        return CaBuAr(
            root=root,
            split=split,
            bands=bands,
            transforms=transforms,
            download=True,
            checksum=True,
        )

    def test_getitem(self, dataset: CaBuAr) -> None:
        """A sample is a dict holding a 3-D image tensor and a 2-D mask tensor."""
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x['image'], torch.Tensor)
        assert isinstance(x['mask'], torch.Tensor)

        # Image tests
        assert x['image'].ndim == 3

        # The channel count is twice the number of selected bands.
        if dataset.bands == CaBuAr.rgb_bands:
            assert x['image'].shape[0] == 2 * 3
        elif dataset.bands == CaBuAr.all_bands:
            assert x['image'].shape[0] == 2 * 12

        # Mask tests:
        assert x['mask'].ndim == 2

    def test_len(self, dataset: CaBuAr) -> None:
        """The dataset reports four samples for the parametrized split."""
        assert len(dataset) == 4

    def test_already_downloaded(self, dataset: CaBuAr) -> None:
        """Constructing again over an existing root with download=True succeeds."""
        CaBuAr(root=dataset.root, download=True)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """An empty root without download raises DatasetNotFoundError."""
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            CaBuAr(tmp_path)

    def test_invalid_bands(self) -> None:
        """Unknown band names are rejected with an AssertionError."""
        with pytest.raises(AssertionError):
            CaBuAr(bands=['OK', 'BK'])

    def test_plot(self, dataset: CaBuAr) -> None:
        """Plotting works both without and with a 'prediction' entry."""
        dataset.plot(dataset[0], suptitle='Test')
        plt.close()

        sample = dataset[0]
        sample['prediction'] = sample['mask'].clone()
        dataset.plot(sample, suptitle='prediction')
        plt.close()

    def test_plot_rgb(self, dataset: CaBuAr) -> None:
        """Plotting a dataset that lacks the RGB bands raises ValueError."""
        dataset = CaBuAr(root=dataset.root, bands=['B02'])
        with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"):
            dataset.plot(dataset[0], suptitle='Single Band')

    def test_invalid_split(self, dataset: CaBuAr) -> None:
        """An unknown split name is rejected with an AssertionError."""
        with pytest.raises(AssertionError):
            CaBuAr(dataset.root, split='foo')
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .agrifieldnet import AgriFieldNetDataModule
from .bigearthnet import BigEarthNetDataModule
from .cabuar import CaBuArDataModule
from .chabud import ChaBuDDataModule
from .chesapeake import ChesapeakeCVPRDataModule
from .cowc import COWCCountingDataModule
Expand Down Expand Up @@ -64,6 +65,7 @@
'SouthAfricaCropTypeDataModule',
# NonGeoDataset
'BigEarthNetDataModule',
'CaBuArDataModule',
'ChaBuDDataModule',
'COWCCountingDataModule',
'DeepGlobeLandCoverDataModule',
Expand Down
79 changes: 79 additions & 0 deletions torchgeo/datamodules/cabuar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""CaBuAr datamodule."""

from typing import Any

import torch
from einops import repeat

from ..datasets import CaBuAr
from .geo import NonGeoDataModule


class CaBuArDataModule(NonGeoDataModule):
    """LightningDataModule implementation for the CaBuAr dataset.

    Uses the train/val/test splits from the dataset

    .. versionadded:: 0.6
    """

    # Per-band min/max values computed on the train set using 2/98 percentiles;
    # entries follow the order of CaBuAr.all_bands (they are indexed through it).
    min = torch.tensor(
        [0.0, 1.0, 73.0, 39.0, 46.0, 25.0, 26.0, 21.0, 17.0, 1.0, 20.0, 21.0]
    )
    max = torch.tensor(
        [1926.0, 2174.0, 2527.0, 2950.0, 3237.0, 3717.0]
        + [4087.0, 4271.0, 4290.0, 4219.0, 4568.0, 3753.0]
    )

    def __init__(
        self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Initialize a new CaBuArDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.CaBuAr`.
        """
        # Pick the statistics of the requested bands (all bands by default).
        selected = kwargs.get('bands', CaBuAr.all_bands)
        positions = [CaBuAr.all_bands.index(name) for name in selected]
        lower = self.min[positions]
        upper = self.max[positions]

        # Change detection, 2 images from different times: the per-band
        # statistics are repeated once per time step.
        lower = repeat(lower, 'c -> (t c)', t=2)
        upper = repeat(upper, 'c -> (t c)', t=2)

        # Set as instance attributes before the parent initializer runs.
        self.mean = lower
        self.std = upper - lower

        super().__init__(CaBuAr, batch_size, num_workers, **kwargs)

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either 'fit', 'validate', 'test'.
        """
        if stage == 'test':
            self.test_dataset = CaBuAr(split='test', **self.kwargs)
            return

        # 'fit' and 'validate' both build the train and the val datasets.
        if stage in ('fit', 'validate'):
            self.train_dataset = CaBuAr(split='train', **self.kwargs)
            self.val_dataset = CaBuAr(split='val', **self.kwargs)
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .benin_cashews import BeninSmallHolderCashews
from .bigearthnet import BigEarthNet
from .biomassters import BioMassters
from .cabuar import CaBuAr
from .cbf import CanadianBuildingFootprints
from .cdl import CDL
from .chabud import ChaBuD
Expand Down Expand Up @@ -198,6 +199,7 @@
'BeninSmallHolderCashews',
'BigEarthNet',
'BioMassters',
'CaBuAr',
'ChaBuD',
'CloudCoverDetection',
'COWC',
Expand Down
Loading
Loading