Skip to content

Commit

Permalink
Add Inria datamodule (microsoft#498)
Browse files Browse the repository at this point in the history
* Add Inria Datamodule

* Fix up

* Add predict.py

* Integrate kornia fns for extracting & combining

Requires kornia/kornia#1558 to be merged

* transform creates problem when calculating metrics

* Update

* Use dict.get

* Add tests & update test data

* Add Inria datamodule to docs

* Reduce test data size

* Datamodules always have predict_dataloader

* Remove comments

* Update predict.py

* Add PredictDataset

* Fix tests

* Update inria.yaml

* Clarify predict_on doc

* Refactor

* Update min kornia

* Update inria.yaml

* Remove predict utilities

* Trainer fix

* Use kornia's compute_padding

* kornia docfix

* Use stable docs

* Fixes
  • Loading branch information
ashnair1 authored and remtav committed May 26, 2022
1 parent 4855c8f commit eeac9f7
Show file tree
Hide file tree
Showing 29 changed files with 389 additions and 20 deletions.
30 changes: 30 additions & 0 deletions conf/inria.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
program:
overwrite: True


trainer:
gpus: 1
min_epochs: 5
max_epochs: 100
benchmark: True
log_every_n_steps: 2

experiment:
task: "inria"
name: "inria_test"
module:
loss: "ce"
segmentation_model: "unet"
encoder_name: "resnet18"
encoder_weights: "imagenet"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
num_classes: 2
ignore_zeros: True # class 0 not used for scoring
datamodule:
root_dir: "data/inria"
batch_size: 2
num_workers: 32
patch_size: 512
num_patches_per_tile: 4
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ FAIR1M (Fine-grAined object recognItion in high-Resolution imagery)

.. autoclass:: FAIR1MDataModule

Inria Aerial Image Labeling
^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: InriaAerialImageLabelingDataModule

LandCover.ai (Land Cover from Aerial Imagery)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@

# sphinx.ext.intersphinx
intersphinx_mapping = {
"kornia": ("https://kornia.readthedocs.io/en/stable/", None),
"matplotlib": ("https://matplotlib.org/stable/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"python": ("https://docs.python.org/3", None),
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ install_requires =
einops
# fiona 1.5+ required for fiona.transform module
fiona>=1.5
# kornia 0.5.11+ required for kornia.augmentation.PadTo
kornia>=0.5.11
# kornia 0.6.4+ required for kornia.contrib.compute_padding
kornia>=0.6.4
matplotlib
numpy
# omegaconf 2.1+ required for to_object method
Expand Down
20 changes: 20 additions & 0 deletions tests/conf/inria.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
experiment:
task: "inria"
module:
loss: "ce"
segmentation_model: "unet"
encoder_name: "resnet18"
encoder_weights: "imagenet"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
num_classes: 2
ignore_zeros: True # class 0 not used for scoring
datamodule:
root_dir: "tests/data/inria"
batch_size: 1
num_workers: 0
val_split_pct: 0.2
test_split_pct: 0.2
patch_size: 2
num_patches_per_tile: 2
Binary file modified tests/data/inria/AerialImageDataset/test/images/austin10.tif
Binary file not shown.
Binary file modified tests/data/inria/AerialImageDataset/test/images/austin11.tif
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified tests/data/inria/AerialImageDataset/train/gt/austin1.tif
Binary file not shown.
Binary file modified tests/data/inria/AerialImageDataset/train/gt/austin2.tif
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified tests/data/inria/AerialImageDataset/train/images/austin1.tif
Binary file not shown.
Binary file modified tests/data/inria/AerialImageDataset/train/images/austin2.tif
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified tests/data/inria/NEW2-AerialImageDataset.zip
Binary file not shown.
6 changes: 3 additions & 3 deletions tests/data/inria/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def generate_test_data(root: str, n_samples: int = 2) -> str:
str: md5 hash of created archive
"""
dtype = np.dtype("uint8")
size = (64, 64)
size = (8, 8)

driver = "GTiff"
transform = Affine(0.3, 0.0, 616500.0, 0.0, -0.3, 3345000.0)
Expand Down Expand Up @@ -83,9 +83,9 @@ def generate_test_data(root: str, n_samples: int = 2) -> str:
archive_path, "zip", root_dir=root, base_dir="AerialImageDataset"
)
shutil.rmtree(folder_path)
return calculate_md5(archive_path + ".zip")
return calculate_md5(f"{archive_path}.zip")


if __name__ == "__main__":
md5_hash = generate_test_data(os.getcwd(), 2)
md5_hash = generate_test_data(os.getcwd(), 5)
print(md5_hash)
70 changes: 70 additions & 0 deletions tests/datamodules/test_inria.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import pytest
from _pytest.fixtures import SubRequest

from torchgeo.datamodules import InriaAerialImageLabelingDataModule

TEST_DATA_DIR = os.path.join("tests", "data", "inria")


class TestInriaAerialImageLabelingDataModule:
@pytest.fixture(
params=zip([0.2, 0.2, 0.0], [0.2, 0.0, 0.0], ["test", "test", "test"])
)
def datamodule(self, request: SubRequest) -> InriaAerialImageLabelingDataModule:
val_split_pct, test_split_pct, predict_on = request.param
patch_size = 2 # (2,2)
num_patches_per_tile = 2
root = TEST_DATA_DIR
batch_size = 1
num_workers = 0
dm = InriaAerialImageLabelingDataModule(
root,
batch_size,
num_workers,
val_split_pct,
test_split_pct,
patch_size,
num_patches_per_tile,
predict_on=predict_on,
)
dm.prepare_data()
dm.setup()
return dm

def test_train_dataloader(
self, datamodule: InriaAerialImageLabelingDataModule
) -> None:
sample = next(iter(datamodule.train_dataloader()))
assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2)
assert sample["image"].shape[0] == sample["mask"].shape[0] == 2
assert sample["image"].shape[1] == 3
assert sample["mask"].shape[1] == 1

def test_val_dataloader(
self, datamodule: InriaAerialImageLabelingDataModule
) -> None:
sample = next(iter(datamodule.val_dataloader()))
if datamodule.val_split_pct > 0.0:
assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2)
assert sample["image"].shape[0] == sample["mask"].shape[0] == 2

def test_test_dataloader(
self, datamodule: InriaAerialImageLabelingDataModule
) -> None:
sample = next(iter(datamodule.test_dataloader()))
if datamodule.test_split_pct > 0.0:
assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2)
assert sample["image"].shape[0] == sample["mask"].shape[0] == 2

def test_predict_dataloader(
self, datamodule: InriaAerialImageLabelingDataModule
) -> None:
sample = next(iter(datamodule.predict_dataloader()))
assert len(sample["image"].shape) == 5
assert sample["image"].shape[-2:] == (2, 2)
assert sample["image"].shape[2] == 3
4 changes: 2 additions & 2 deletions tests/datasets/test_inria.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def dataset(
) -> InriaAerialImageLabeling:

root = os.path.join("tests", "data", "inria")
test_md5 = "f23caf363389ef59de55fad11197c161"
test_md5 = "478688944e4797c097d9387fd0b3f038"
monkeypatch.setattr(InriaAerialImageLabeling, "md5", test_md5)
transforms = nn.Identity() # type: ignore[no-untyped-call]
return InriaAerialImageLabeling(
Expand All @@ -39,7 +39,7 @@ def test_getitem(self, dataset: InriaAerialImageLabeling) -> None:
assert x["image"].ndim == 3

def test_len(self, dataset: InriaAerialImageLabeling) -> None:
assert len(dataset) == 2
assert len(dataset) == 5

def test_already_downloaded(self, dataset: InriaAerialImageLabeling) -> None:
InriaAerialImageLabeling(root=dataset.root)
Expand Down
2 changes: 2 additions & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torchgeo.datamodules import (
ChesapeakeCVPRDataModule,
ETCI2021DataModule,
InriaAerialImageLabelingDataModule,
LandCoverAIDataModule,
NAIPChesapeakeDataModule,
OSCDDataModule,
Expand All @@ -34,6 +35,7 @@ class TestSemanticSegmentationTask:
[
("chesapeake_cvpr_5", ChesapeakeCVPRDataModule),
("etci2021", ETCI2021DataModule),
("inria", InriaAerialImageLabelingDataModule),
("landcoverai", LandCoverAIDataModule),
("naipchesapeake", NAIPChesapeakeDataModule),
("oscd_all", OSCDDataModule),
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .etci2021 import ETCI2021DataModule
from .eurosat import EuroSATDataModule
from .fair1m import FAIR1MDataModule
from .inria import InriaAerialImageLabelingDataModule
from .landcoverai import LandCoverAIDataModule
from .loveda import LoveDADataModule
from .naip import NAIPChesapeakeDataModule
Expand All @@ -33,6 +34,7 @@
"ETCI2021DataModule",
"EuroSATDataModule",
"FAIR1MDataModule",
"InriaAerialImageLabelingDataModule",
"LandCoverAIDataModule",
"LoveDADataModule",
"NASAMarineDebrisDataModule",
Expand Down
Loading

0 comments on commit eeac9f7

Please sign in to comment.