Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add InriaAerialImageLabeling dataset #355

Merged
merged 17 commits into from
Jan 13, 2022
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ IDTReeS

.. autoclass:: IDTReeS

Inria Aerial Image Labeling
^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: InriaAerialImageLabeling

LandCover.ai (Land Cover from Aerial Imagery)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/inria/NEW2-AerialImageDataset.zip
Binary file not shown.
91 changes: 91 additions & 0 deletions tests/data/inria/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
ashnair1 marked this conversation as resolved.
Show resolved Hide resolved
import shutil

import numpy as np
import rasterio as rio
from rasterio.crs import CRS
from rasterio.transform import Affine
from torchvision.datasets.utils import calculate_md5


def write_data(
    path: str, img: np.ndarray, driver: str, crs: CRS, transform: Affine
) -> None:
    """Writes an array to disk as a three-band georeferenced raster.

    The same array is written to every band of the output file.

    Args:
        path (str): Path of the raster file to create
        img (np.ndarray): Pixel data; its first two dimensions are used as
            the raster height and width, and its dtype as the raster dtype
        driver (str): Raster driver name forwarded to ``rasterio.open``
        crs (CRS): Coordinate reference system of the output raster
        transform (Affine): Affine transform of the output raster
    """
    n_rows, n_cols = img.shape[0], img.shape[1]
    profile = {
        "driver": driver,
        "height": n_rows,
        "width": n_cols,
        "count": 3,
        "dtype": img.dtype,
        "crs": crs,
        "transform": transform,
    }
    with rio.open(path, "w", **profile) as dst:
        # Band indexes are 1-based
        for band in range(dst.count):
            dst.write(img, band + 1)


def generate_test_data(root: str, n_samples: int = 2) -> str:
    """Creates test data archive for InriaAerialImageLabeling dataset and
    returns its md5 hash.

    For each sample, three random 64x64 uint8 rasters are written below
    ``<root>/AerialImageDataset``: a training image (``train/images``), a
    training label with values 0 or 1 (``train/gt``) and a test image
    (``test/images``). The folder is then zipped to
    ``<root>/NEW2-AerialImageDataset.zip`` and the unpacked folder is removed.

    Args:
        root (str): Path to store test data
        n_samples (int, optional): Number of samples. Defaults to 2.

    Returns:
        str: md5 hash of created archive
    """
    dtype = np.dtype("uint8")
    size = (64, 64)
    # Loop-invariant, so computed once. np.random.randint treats this as an
    # exclusive upper bound.
    dtype_max = np.iinfo(dtype).max

    driver = "GTiff"
    transform = Affine(0.3, 0.0, 616500.0, 0.0, -0.3, 3345000.0)
    crs = CRS.from_epsg(26914)

    folder_path = os.path.join(root, "AerialImageDataset")

    img_dir = os.path.join(folder_path, "train", "images")
    lbl_dir = os.path.join(folder_path, "train", "gt")
    timg_dir = os.path.join(folder_path, "test", "images")

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    for directory in (img_dir, lbl_dir, timg_dir):
        os.makedirs(directory, exist_ok=True)

    for i in range(n_samples):
        img = np.random.randint(dtype_max, size=size, dtype=dtype)
        lbl = np.random.randint(2, size=size, dtype=dtype)
        timg = np.random.randint(dtype_max, size=size, dtype=dtype)

        img_path = os.path.join(img_dir, f"austin{i+1}.tif")
        lbl_path = os.path.join(lbl_dir, f"austin{i+1}.tif")
        timg_path = os.path.join(timg_dir, f"austin{i+10}.tif")

        write_data(img_path, img, driver, crs, transform)
        write_data(lbl_path, lbl, driver, crs, transform)
        write_data(timg_path, timg, driver, crs, transform)

    # Create archive
    archive_path = os.path.join(root, "NEW2-AerialImageDataset")
    shutil.make_archive(
        archive_path, "zip", root_dir=root, base_dir="AerialImageDataset"
    )
    # Only the archive is kept; the unpacked folder is removed
    shutil.rmtree(folder_path)
    return calculate_md5(archive_path + ".zip")
ashnair1 marked this conversation as resolved.
Show resolved Hide resolved


if __name__ == "__main__":
    # Build the archive in the current working directory and print its md5
    print(generate_test_data(os.getcwd(), 2))
67 changes: 67 additions & 0 deletions tests/datasets/test_inria.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
ashnair1 marked this conversation as resolved.
Show resolved Hide resolved
import shutil
from typing import Generator

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch

from torchgeo.datasets import InriaAerialImageLabeling


class TestInriaAerialImageLabeling:
    """Tests for InriaAerialImageLabeling, run against tests/data/inria."""

    @pytest.fixture(params=["train", "test"])
    def dataset(
        self, request: SubRequest, monkeypatch: MonkeyPatch
    ) -> InriaAerialImageLabeling:
        """Builds the dataset for each split with the md5 patched to the
        hash of the test archive, so that ``checksum=True`` passes."""
        root = os.path.join("tests", "data", "inria")
        test_md5 = "f23caf363389ef59de55fad11197c161"
        monkeypatch.setattr(InriaAerialImageLabeling, "md5", test_md5)
        transforms = nn.Identity()  # type: ignore[attr-defined]
        return InriaAerialImageLabeling(
            root, split=request.param, transforms=transforms, checksum=True
        )

    def test_getitem(self, dataset: InriaAerialImageLabeling) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        # Masks are only checked for the train split
        if dataset.split == "train":
            assert isinstance(x["mask"], torch.Tensor)
            assert x["mask"].ndim == 2
        assert x["image"].shape[0] == 3
        assert x["image"].ndim == 3

    def test_len(self, dataset: InriaAerialImageLabeling) -> None:
        assert len(dataset) == 2

    def test_already_downloaded(self, dataset: InriaAerialImageLabeling) -> None:
        InriaAerialImageLabeling(root=dataset.root)

    def test_not_downloaded(self, tmp_path: str) -> None:
        # An empty directory contains neither the archive nor extracted data
        with pytest.raises(RuntimeError, match="Dataset not found"):
            InriaAerialImageLabeling(str(tmp_path))

    def test_dataset_checksum(
        self, dataset: InriaAerialImageLabeling, monkeypatch: MonkeyPatch
    ) -> None:
        # Patch through monkeypatch so the class attribute is restored after
        # the test instead of being assigned directly on the class
        monkeypatch.setattr(InriaAerialImageLabeling, "md5", "randommd5hash123")
        # Remove the extracted folder before re-instantiating with checksum=True
        shutil.rmtree(os.path.join(dataset.root, dataset.directory))
        with pytest.raises(RuntimeError, match="Dataset corrupted"):
            InriaAerialImageLabeling(root=dataset.root, checksum=True)

    def test_plot(self, dataset: InriaAerialImageLabeling) -> None:
        x = dataset[0].copy()
        if dataset.split == "train":
            # Reuse the mask as a stand-in prediction
            x["prediction"] = x["mask"]
        dataset.plot(x, suptitle="Test")
        plt.close()
        dataset.plot(x, show_titles=False)
        plt.close()
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from .gid15 import GID15
from .idtrees import IDTReeS
from .inria import InriaAerialImageLabeling
from .landcoverai import LandCoverAI
from .landsat import (
Landsat,
Expand Down Expand Up @@ -120,6 +121,7 @@
"FAIR1M",
"GID15",
"IDTReeS",
"InriaAerialImageLabeling",
"LandCoverAI",
"LEVIRCDPlus",
"LoveDA",
Expand Down
Loading