Add Inria datamodule #498

Merged on May 22, 2022 (26 commits)
30 changes: 30 additions & 0 deletions conf/inria.yaml
@@ -0,0 +1,30 @@
program:
  overwrite: True


trainer:
  gpus: 1
  min_epochs: 5
  max_epochs: 100
  benchmark: True
  log_every_n_steps: 2

experiment:
  task: "inria"
  name: "inria_test"
  module:
    loss: "ce"
    segmentation_model: "unet"
    encoder_name: "resnet18"
    encoder_weights: "imagenet"
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
    in_channels: 3
    num_classes: 2
    ignore_zeros: True # class 0 not used for scoring
  datamodule:
    root_dir: "data/inria"
    batch_size: 2
Collaborator

This is pretty small. Note that conf/inria.yaml should contain the optimal hyperparameters, while tests/conf/inria.yaml should contain the bare minimum needed to get the tests to run.

Collaborator Author

Added tests/conf/inria.yaml.

Collaborator

Now that the small values required to make the tests run fast are in tests/conf/inria.yaml, should these values be increased? I guess that would require hyperparameter tuning to determine the best values.

    num_workers: 32
    patch_size: 512
    num_patches_per_tile: 4
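
The review thread on this file leaves the tuned values open. As a rough sketch of how candidate configurations could be enumerated before a sweep, the snippet below builds a small grid of dotted-key overrides. The candidate values and the `candidate_overrides` helper are illustrative assumptions, not values or tooling proposed in this PR.

```python
from itertools import product

# Hypothetical search space; these values are placeholders, not tuned results.
SEARCH_SPACE = {
    "experiment.datamodule.batch_size": [2, 8, 16],
    "experiment.module.learning_rate": [1e-2, 1e-3, 1e-4],
}


def candidate_overrides(space):
    """Yield one dict of dotted-key overrides per point in the grid."""
    keys = list(space)
    for values in product(*(space[key] for key in keys)):
        yield dict(zip(keys, values))


overrides = list(candidate_overrides(SEARCH_SPACE))
print(len(overrides))  # 3 batch sizes x 3 learning rates
```

Each dict could then be applied on top of conf/inria.yaml for one training run; how the runs are scored and compared is left open here.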
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
@@ -44,6 +44,11 @@ FAIR1M (Fine-grAined object recognItion in high-Resolution imagery)

 .. autoclass:: FAIR1MDataModule

+Inria Aerial Image Labeling
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. autoclass:: InriaAerialImageLabelingDataModule
+
 LandCover.ai (Land Cover from Aerial Imagery)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

1 change: 1 addition & 0 deletions docs/conf.py
@@ -104,6 +104,7 @@

 # sphinx.ext.intersphinx
 intersphinx_mapping = {
+    "kornia": ("https://kornia.readthedocs.io/en/stable/", None),
     "matplotlib": ("https://matplotlib.org/stable/", None),
     "numpy": ("https://numpy.org/doc/stable/", None),
     "python": ("https://docs.python.org/3", None),
4 changes: 2 additions & 2 deletions setup.cfg
@@ -31,8 +31,8 @@ install_requires =
     einops
     # fiona 1.5+ required for fiona.transform module
     fiona>=1.5
-    # kornia 0.5.11+ required for kornia.augmentation.PadTo
-    kornia>=0.5.11
+    # kornia 0.6.4+ required for kornia.contrib.compute_padding
+    kornia>=0.6.4
     matplotlib
     numpy
     # omegaconf 2.1+ required for to_object method
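
The new kornia floor is tied to kornia.contrib.compute_padding. The sketch below does not call kornia; it only illustrates, in plain Python, the kind of arithmetic involved when a tile has to be padded so that it divides evenly into patches. The helper name `padding_to_multiple` and the 5000-pixel tile side are assumptions for illustration, and the exact behavior of the kornia function should be checked against its documentation.

```python
from typing import Tuple


def padding_to_multiple(size: int, patch_size: int) -> Tuple[int, int]:
    """Return (before, after) padding so that size + padding is a multiple of patch_size."""
    remainder = size % patch_size
    total = 0 if remainder == 0 else patch_size - remainder
    before = total // 2
    # Any odd leftover pixel goes on the "after" side.
    return before, total - before


# Assumed tile side of 5000 px with the patch_size of 512 from conf/inria.yaml.
before, after = padding_to_multiple(5000, 512)
print(before, after, 5000 + before + after)
```

Splitting the padding across both sides is a design choice of this sketch; padding only one side would give the same padded size.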
20 changes: 20 additions & 0 deletions tests/conf/inria.yaml
@@ -0,0 +1,20 @@
experiment:
  task: "inria"
  module:
    loss: "ce"
    segmentation_model: "unet"
    encoder_name: "resnet18"
    encoder_weights: "imagenet"
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
    in_channels: 3
    num_classes: 2
    ignore_zeros: True # class 0 not used for scoring
  datamodule:
    root_dir: "tests/data/inria"
    batch_size: 1
    num_workers: 0
    val_split_pct: 0.2
    test_split_pct: 0.2
    patch_size: 2
    num_patches_per_tile: 2
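
For comparison with conf/inria.yaml, the arithmetic below shows how many patches one batch would hold under each config, assuming a batch contains batch_size x num_patches_per_tile patches. That assumption is inferred from tests/datamodules/test_inria.py in this PR, which expects a leading dimension of 2 with batch_size 1 and num_patches_per_tile 2; it is not taken from the datamodule source, which is not shown here. The `patches_per_batch` helper is hypothetical.

```python
def patches_per_batch(batch_size: int, num_patches_per_tile: int) -> int:
    """Patches in one batch, assuming every tile contributes the same number of patches."""
    return batch_size * num_patches_per_tile


# Values copied from the two config files in this PR.
configs = {
    "conf/inria.yaml": {"batch_size": 2, "num_patches_per_tile": 4},
    "tests/conf/inria.yaml": {"batch_size": 1, "num_patches_per_tile": 2},
}

for path, cfg in configs.items():
    print(path, patches_per_batch(**cfg))
```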
Binary file modified tests/data/inria/AerialImageDataset/test/images/austin10.tif
Binary file not shown.
Binary file modified tests/data/inria/AerialImageDataset/test/images/austin11.tif
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified tests/data/inria/AerialImageDataset/train/gt/austin1.tif
Binary file not shown.
Binary file modified tests/data/inria/AerialImageDataset/train/gt/austin2.tif
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified tests/data/inria/AerialImageDataset/train/images/austin1.tif
Binary file not shown.
Binary file modified tests/data/inria/AerialImageDataset/train/images/austin2.tif
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified tests/data/inria/NEW2-AerialImageDataset.zip
Binary file not shown.
6 changes: 3 additions & 3 deletions tests/data/inria/data.py
@@ -43,7 +43,7 @@ def generate_test_data(root: str, n_samples: int = 2) -> str:
         str: md5 hash of created archive
     """
     dtype = np.dtype("uint8")
-    size = (64, 64)
+    size = (8, 8)

     driver = "GTiff"
     transform = Affine(0.3, 0.0, 616500.0, 0.0, -0.3, 3345000.0)
@@ -83,9 +83,9 @@ def generate_test_data(root: str, n_samples: int = 2) -> str:
         archive_path, "zip", root_dir=root, base_dir="AerialImageDataset"
     )
     shutil.rmtree(folder_path)
-    return calculate_md5(archive_path + ".zip")
+    return calculate_md5(f"{archive_path}.zip")


 if __name__ == "__main__":
-    md5_hash = generate_test_data(os.getcwd(), 2)
+    md5_hash = generate_test_data(os.getcwd(), 5)
     print(md5_hash)
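
generate_test_data returns the MD5 of the regenerated archive, which is why the hash in tests/datasets/test_inria.py also changes in this PR. The calculate_md5 import is outside the hunks shown here; as a standard-library sketch of the same idea, a file can be hashed in chunks as below. The `file_md5` helper is hypothetical and is not the function the script uses.

```python
import hashlib


def file_md5(path: str, chunk_size: int = 1024 * 1024) -> str:
    """Return the hex MD5 digest of a file, reading it in fixed-size chunks."""
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel stops once read() returns an empty bytes object.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()
```

Reading in chunks keeps memory use flat even for large archives, although the test archive here is tiny.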
70 changes: 70 additions & 0 deletions tests/datamodules/test_inria.py
@@ -0,0 +1,70 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import pytest
from _pytest.fixtures import SubRequest

from torchgeo.datamodules import InriaAerialImageLabelingDataModule

TEST_DATA_DIR = os.path.join("tests", "data", "inria")


class TestInriaAerialImageLabelingDataModule:
    @pytest.fixture(
        params=zip([0.2, 0.2, 0.0], [0.2, 0.0, 0.0], ["test", "test", "test"])
    )
    def datamodule(self, request: SubRequest) -> InriaAerialImageLabelingDataModule:
        val_split_pct, test_split_pct, predict_on = request.param
        patch_size = 2  # (2,2)
        num_patches_per_tile = 2
        root = TEST_DATA_DIR
        batch_size = 1
        num_workers = 0
        dm = InriaAerialImageLabelingDataModule(
            root,
            batch_size,
            num_workers,
            val_split_pct,
            test_split_pct,
            patch_size,
            num_patches_per_tile,
            predict_on=predict_on,
        )
        dm.prepare_data()
        dm.setup()
        return dm

    def test_train_dataloader(
        self, datamodule: InriaAerialImageLabelingDataModule
    ) -> None:
        sample = next(iter(datamodule.train_dataloader()))
        assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2)
        assert sample["image"].shape[0] == sample["mask"].shape[0] == 2
        assert sample["image"].shape[1] == 3
        assert sample["mask"].shape[1] == 1

    def test_val_dataloader(
        self, datamodule: InriaAerialImageLabelingDataModule
    ) -> None:
        sample = next(iter(datamodule.val_dataloader()))
        if datamodule.val_split_pct > 0.0:
            assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2)
        assert sample["image"].shape[0] == sample["mask"].shape[0] == 2

    def test_test_dataloader(
        self, datamodule: InriaAerialImageLabelingDataModule
    ) -> None:
        sample = next(iter(datamodule.test_dataloader()))
        if datamodule.test_split_pct > 0.0:
            assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2)
        assert sample["image"].shape[0] == sample["mask"].shape[0] == 2

    def test_predict_dataloader(
        self, datamodule: InriaAerialImageLabelingDataModule
    ) -> None:
        sample = next(iter(datamodule.predict_dataloader()))
        assert len(sample["image"].shape) == 5
        assert sample["image"].shape[-2:] == (2, 2)
        assert sample["image"].shape[2] == 3
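
The fixture's params come from a zip over three lists, so each test in the class runs once per (val_split_pct, test_split_pct, predict_on) triple. Expanding the zip makes the three cases explicit; this snippet only reproduces the lists from the fixture and does not touch the datamodule.

```python
# Same three lists as passed to pytest.fixture(params=zip(...)) in the test file.
params = list(zip([0.2, 0.2, 0.0], [0.2, 0.0, 0.0], ["test", "test", "test"]))

for val_split_pct, test_split_pct, predict_on in params:
    print(val_split_pct, test_split_pct, predict_on)
```

The cases cover both splits enabled, only the validation split enabled, and no splits, always with predict_on set to "test".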
4 changes: 2 additions & 2 deletions tests/datasets/test_inria.py
@@ -21,7 +21,7 @@ def dataset(
     ) -> InriaAerialImageLabeling:

         root = os.path.join("tests", "data", "inria")
-        test_md5 = "f23caf363389ef59de55fad11197c161"
+        test_md5 = "478688944e4797c097d9387fd0b3f038"
         monkeypatch.setattr(InriaAerialImageLabeling, "md5", test_md5)
         transforms = nn.Identity()  # type: ignore[no-untyped-call]
         return InriaAerialImageLabeling(
@@ -39,7 +39,7 @@ def test_getitem(self, dataset: InriaAerialImageLabeling) -> None:
         assert x["image"].ndim == 3

     def test_len(self, dataset: InriaAerialImageLabeling) -> None:
-        assert len(dataset) == 2
+        assert len(dataset) == 5

     def test_already_downloaded(self, dataset: InriaAerialImageLabeling) -> None:
         InriaAerialImageLabeling(root=dataset.root)
2 changes: 2 additions & 0 deletions tests/trainers/test_segmentation.py
@@ -14,6 +14,7 @@
 from torchgeo.datamodules import (
     ChesapeakeCVPRDataModule,
     ETCI2021DataModule,
+    InriaAerialImageLabelingDataModule,
     LandCoverAIDataModule,
     NAIPChesapeakeDataModule,
     OSCDDataModule,
@@ -34,6 +35,7 @@ class TestSemanticSegmentationTask:
         [
             ("chesapeake_cvpr_5", ChesapeakeCVPRDataModule),
             ("etci2021", ETCI2021DataModule),
+            ("inria", InriaAerialImageLabelingDataModule),
             ("landcoverai", LandCoverAIDataModule),
             ("naipchesapeake", NAIPChesapeakeDataModule),
             ("oscd_all", OSCDDataModule),
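
The new ("inria", InriaAerialImageLabelingDataModule) entry pairs a config name with a datamodule class. The body of the parametrized test is outside this hunk, so the following is a guess: assuming the name is resolved to a YAML file under tests/conf, the lookup could look like this. The `conf_path_for` helper is hypothetical.

```python
import os


def conf_path_for(name: str) -> str:
    """Map a parametrized config name to its assumed location under tests/conf."""
    return os.path.join("tests", "conf", f"{name}.yaml")


print(conf_path_for("inria"))
```

Under that assumption, the "inria" entry would pick up the tests/conf/inria.yaml file added in this PR.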
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
@@ -10,6 +10,7 @@
 from .etci2021 import ETCI2021DataModule
 from .eurosat import EuroSATDataModule
 from .fair1m import FAIR1MDataModule
+from .inria import InriaAerialImageLabelingDataModule
 from .landcoverai import LandCoverAIDataModule
 from .loveda import LoveDADataModule
 from .naip import NAIPChesapeakeDataModule
@@ -33,6 +34,7 @@
     "ETCI2021DataModule",
     "EuroSATDataModule",
     "FAIR1MDataModule",
+    "InriaAerialImageLabelingDataModule",
     "LandCoverAIDataModule",
     "LoveDADataModule",
     "NASAMarineDebrisDataModule",