diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 50741377e8d..6c06768bef0 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -224,6 +224,11 @@ LoveDA .. autoclass:: LoveDA +Million-AID +^^^^^^^^^^^ + +.. autoclass:: MillionAID + NASA Marine Debris ^^^^^^^^^^^^^^^^^^ diff --git a/docs/api/non_geo_datasets.csv b/docs/api/non_geo_datasets.csv index 2e37e18e21b..a7bdb09add2 100644 --- a/docs/api/non_geo_datasets.csv +++ b/docs/api/non_geo_datasets.csv @@ -16,6 +16,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands `LandCover.ai`_,S,Aerial,"10,674",5,512x512,0.25--0.5,RGB `LEVIR-CD+`_,CD,Google Earth,985,2,"1,024x1,024",0.5,RGB `LoveDA`_,S,Google Earth,"5,987",7,"1,024x1,024",0.3,RGB +`Million-AID`_,C,Google Earth,1M,51--73,,0.5--153,RGB `NASA Marine Debris`_,OD,PlanetScope,707,1,256x256,3,RGB `OSCD`_,CD,Sentinel-2,24,2,"40--1,180",60,MSI `PatternNet`_,C,Google Earth,"30,400",38,256x256,0.06--5,RGB diff --git a/tests/data/millionaid/data.py b/tests/data/millionaid/data.py new file mode 100644 index 00000000000..03ea05ad5df --- /dev/null +++ b/tests/data/millionaid/data.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import os +import shutil + +import numpy as np +from PIL import Image + +SIZE = 32 + +np.random.seed(0) + +PATHS = { + "train": [ + os.path.join( + "train", "agriculture_land", "grassland", "meadow", "P0115918.jpg" + ), + os.path.join("train", "water_area", "beach", "P0060208.jpg"), + ], + "test": [ + os.path.join("test", "agriculture_land", "grassland", "meadow", "P0115918.jpg"), + os.path.join("test", "water_area", "beach", "P0060208.jpg"), + ], +} + + +def create_file(path: str) -> None: + Z = np.random.rand(SIZE, SIZE, 3) * 255 + img = Image.fromarray(Z.astype("uint8")).convert("RGB") + img.save(path) + + +if __name__ == "__main__": + for split, paths in PATHS.items(): + # remove old data + if os.path.isdir(split): + shutil.rmtree(split) + for path in paths: + os.makedirs(os.path.dirname(path), exist_ok=True) + create_file(path) + + # compress data + shutil.make_archive(split, "zip", ".", split) + + # Compute checksums + with open(split + ".zip", "rb") as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(f"{split}: {md5}") diff --git a/tests/data/millionaid/test.zip b/tests/data/millionaid/test.zip new file mode 100644 index 00000000000..15a0bbd3f46 Binary files /dev/null and b/tests/data/millionaid/test.zip differ diff --git a/tests/data/millionaid/test/agriculture_land/grassland/meadow/P0115918.jpg b/tests/data/millionaid/test/agriculture_land/grassland/meadow/P0115918.jpg new file mode 100644 index 00000000000..1d04bf8a523 Binary files /dev/null and b/tests/data/millionaid/test/agriculture_land/grassland/meadow/P0115918.jpg differ diff --git a/tests/data/millionaid/test/water_area/beach/P0060208.jpg b/tests/data/millionaid/test/water_area/beach/P0060208.jpg new file mode 100644 index 00000000000..226ff70e36b Binary files /dev/null and b/tests/data/millionaid/test/water_area/beach/P0060208.jpg differ diff --git a/tests/data/millionaid/train.zip b/tests/data/millionaid/train.zip new file mode 100644 index 00000000000..c3fdae60c55 Binary files /dev/null and b/tests/data/millionaid/train.zip differ diff --git a/tests/data/millionaid/train/agriculture_land/grassland/meadow/P0115918.jpg b/tests/data/millionaid/train/agriculture_land/grassland/meadow/P0115918.jpg new file mode 100644 index 00000000000..afc980b49dd 
Binary files /dev/null and b/tests/data/millionaid/train/agriculture_land/grassland/meadow/P0115918.jpg differ diff --git a/tests/data/millionaid/train/water_area/beach/P0060208.jpg b/tests/data/millionaid/train/water_area/beach/P0060208.jpg new file mode 100644 index 00000000000..8d742f2c51b Binary files /dev/null and b/tests/data/millionaid/train/water_area/beach/P0060208.jpg differ diff --git a/tests/datasets/test_millionaid.py b/tests/datasets/test_millionaid.py new file mode 100644 index 00000000000..751567e28a8 --- /dev/null +++ b/tests/datasets/test_millionaid.py @@ -0,0 +1,64 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest + +from torchgeo.datasets import MillionAID + + +class TestMillionAID: + @pytest.fixture( + scope="class", params=zip(["train", "test"], ["multi-class", "multi-label"]) + ) + def dataset(self, request: SubRequest) -> MillionAID: + root = os.path.join("tests", "data", "millionaid") + split, task = request.param + transforms = nn.Identity() + return MillionAID( + root=root, split=split, task=task, transforms=transforms, checksum=True + ) + + def test_getitem(self, dataset: MillionAID) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert isinstance(x["label"], torch.Tensor) + assert x["image"].shape[0] == 3 + assert x["image"].ndim == 3 + + def test_len(self, dataset: MillionAID) -> None: + assert len(dataset) == 2 + + def test_not_found(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found in"): + MillionAID(str(tmp_path)) + + def test_not_extracted(self, tmp_path: Path) -> None: + url = os.path.join("tests", "data", "millionaid", "train.zip") + shutil.copy(url, tmp_path) + MillionAID(str(tmp_path)) + + def test_corrupted(self, tmp_path: Path) -> None: + with open(os.path.join(tmp_path, "train.zip"), "w") as f: + f.write("bad") + with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + MillionAID(str(tmp_path), checksum=True) + + def test_plot(self, dataset: MillionAID) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + + def test_plot_prediction(self, dataset: MillionAID) -> None: + x = dataset[0].copy() + x["prediction"] = x["label"].clone() + dataset.plot(x) + plt.close() diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 13196988787..f3b7a314626 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -68,6 +68,7 @@ ) from .levircd import LEVIRCDPlus from .loveda import LoveDA +from .millionaid import MillionAID from .naip import NAIP from .nasa_marine_debris import NASAMarineDebris from .nwpu import VHR10 @@ -163,6 +164,7 @@ "LandCoverAI", "LEVIRCDPlus", "LoveDA", + "MillionAID", "NASAMarineDebris", "OSCD", "PatternNet", diff --git a/torchgeo/datasets/millionaid.py b/torchgeo/datasets/millionaid.py new file mode 100644 index 00000000000..5136907ff4d --- /dev/null +++ b/torchgeo/datasets/millionaid.py @@ -0,0 +1,371 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ +"""Million-AID dataset.""" +import glob +import os +from typing import Any, Callable, Dict, List, Optional, cast + +import matplotlib.pyplot as plt +import numpy as np +import torch +from PIL import Image +from torch import Tensor + +from torchgeo.datasets import VisionDataset + +from .utils import check_integrity, extract_archive + + +class MillionAID(VisionDataset): + """Million-AID Dataset. + + The `MillionAID `_ dataset consists + of one million aerial images from Google Earth Engine that offers + either `a multi-class learning task + `_ + with 51 classes or a `multi-label learning task + `_ + with 73 different possible labels. For more details please consult + the accompanying `paper `_. + + Dataset features: + + * RGB aerial images with varying resolutions from 0.5 m to 153 m per pixel + * images within classes can have different pixel dimension + + Dataset format: + + * images are three-channel jpg + + If you use this dataset in your research, please cite the following paper: + + * https://ieeexplore.ieee.org/document/9393553 + + .. versionadded:: 0.3 + """ + + multi_label_categories = [ + "agriculture_land", + "airport_area", + "apartment", + "apron", + "arable_land", + "bare_land", + "baseball_field", + "basketball_court", + "beach", + "bridge", + "cemetery", + "church", + "commercial_area", + "commercial_land", + "dam", + "desert", + "detached_house", + "dry_field", + "factory_area", + "forest", + "golf_course", + "grassland", + "greenhouse", + "ground_track_field", + "helipad", + "highway_area", + "ice_land", + "industrial_land", + "intersection", + "island", + "lake", + "leisure_land", + "meadow", + "mine", + "mining_area", + "mobile_home_park", + "oil_field", + "orchard", + "paddy_field", + "parking_lot", + "pier", + "port_area", + "power_station", + "public_service_land", + "quarry", + "railway", + "railway_area", + "religious_land", + "residential_land", + "river", + "road", + "rock_land", + "roundabout", + "runway", + "solar_power_plant", + "sparse_shrub_land", + "special_land", + "sports_land", + "stadium", + "storage_tank", + "substation", + "swimming_pool", + "tennis_court", + "terraced_field", + "train_station", + "transportation_land", + "unutilized_land", + "viaduct", + "wastewater_plant", + "water_area", + "wind_turbine", + "woodland", + "works", + ] + + multi_class_categories = [ + "apartment", + "apron", + "bare_land", + "baseball_field", + "bapsketball_court", + "beach", + "bridge", + "cemetery", + "church", + "commercial_area", + "dam", + "desert", + "detached_house", + "dry_field", + "forest", + "golf_course", + "greenhouse", + "ground_track_field", + "helipad", + "ice_land", + "intersection", + "island", + "lake", + "meadow", + "mine", + "mobile_home_park", + "oil_field", + "orchard", + "paddy_field", + "parking_lot", + "pier", + "quarry", + "railway", + "river", + "road", + "rock_land", + "roundabout", + "runway", + "solar_power_plant", + "sparse_shrub_land", + "stadium", + "storage_tank", + "substation", + "swimming_pool", + "tennis_court", + "terraced_field", + "train_station", + "viaduct", + "wastewater_plant", + "wind_turbine", + "works", + ] + + md5s = { + "train": "1b40503cafa9b0601653ca36cd788852", + "test": "51a63ee3eeb1351889eacff349a983d8", + } + + filenames = {"train": "train.zip", "test": "test.zip"} + + tasks = ["multi-class", "multi-label"] + splits = ["train", "test"] + + def __init__( + self, + root: str = "data", + task: str = "multi-class", + split: str = "train", + transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = 
None, + checksum: bool = False, + ) -> None: + """Initialize a new MillionAID dataset instance. + + Args: + root: root directory where dataset can be found + task: type of task, either "multi-class" or "multi-label" + split: train or test split + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + RuntimeError: if dataset is not found + """ + self.root = root + self.transforms = transforms + self.checksum = checksum + assert task in self.tasks + assert split in self.splits + self.task = task + self.split = split + + self._verify() + + self.files = self._load_files(self.root) + + self.classes = sorted({cls for f in self.files for cls in f["label"]}) + self.class_to_idx: Dict[str, int] = {c: i for i, c in enumerate(self.classes)} + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + length of the dataset + """ + return len(self.files) + + def __getitem__(self, index: int) -> Dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + data and label at that index + """ + files = self.files[index] + image = self._load_image(files["image"]) + cls_label = [self.class_to_idx[label] for label in files["label"]] + label = torch.tensor(cls_label, dtype=torch.long) + sample = {"image": image, "label": label} + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def _load_files(self, root: str) -> List[Dict[str, Any]]: + """Return the paths of the files in the dataset. + + Args: + root: root directory of dataset + + Returns: + list of dicts containing the image path and the list of labels for each sample + """ + imgs_no_subcat = list( + glob.glob(os.path.join(root, self.split, "*", "*", "*.jpg")) + ) + + imgs_subcat = list( + glob.glob(os.path.join(root, self.split, "*", "*", "*", "*.jpg")) + ) + + scenes = [p.split(os.sep)[-3] for p in imgs_no_subcat] + [ + p.split(os.sep)[-4] for p in imgs_subcat + ] + + subcategories = ["Missing" for p in imgs_no_subcat] + [ + p.split(os.sep)[-3] for p in imgs_subcat + ] + + classes = [p.split(os.sep)[-2] for p in imgs_no_subcat] + [ + p.split(os.sep)[-2] for p in imgs_subcat + ] + + if self.task == "multi-label": + labels = [ + [sc, sub, c] if sub != "Missing" else [sc, c] + for sc, sub, c in zip(scenes, subcategories, classes) + ] + else: + labels = [[c] for c in classes] + + images = imgs_no_subcat + imgs_subcat + + files = [dict(image=img, label=label) for img, label in zip(images, labels)] + + return files + + def _load_image(self, path: str) -> Tensor: + """Load a single image. + + Args: + path: path to the image + + Returns: + the image + """ + with Image.open(path) as img: + array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) + tensor: Tensor = torch.from_numpy(array) + # Convert from HxWxC to CxHxW + tensor = tensor.permute((2, 0, 1)) + return tensor + + def _verify(self) -> None: + """Verify the integrity of the dataset.
+ + Raises: + RuntimeError: if dataset is not found or corrupted + """ + filepath = os.path.join(self.root, self.split) + if os.path.isdir(filepath): + return + + filepath = os.path.join(self.root, self.split + ".zip") + if os.path.isfile(filepath): + if self.checksum and not check_integrity(filepath, self.md5s[self.split]): + raise RuntimeError("Dataset found, but corrupted.") + extract_archive(filepath) + return + + raise RuntimeError( + f"Dataset not found in `root={self.root}` directory, either " + "specify a different `root` directory or manually download " + "the dataset to this directory." + ) + + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + """ + image = np.rollaxis(sample["image"].numpy(), 0, 3) + labels = [self.classes[cast(int, label)] for label in sample["label"]] + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction_labels = [ + self.classes[cast(int, label)] for label in sample["prediction"] + ] + + fig, ax = plt.subplots(figsize=(4, 4)) + ax.imshow(image) + ax.axis("off") + if show_titles: + title = f"Label: {labels}" + if showing_predictions: + title += f"\nPrediction: {prediction_labels}" + ax.set_title(title) + + if suptitle is not None: + plt.suptitle(suptitle) + return fig
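Reviewer note (not part of the patch): since the new class has no download logic, a minimal usage sketch may help show how it is meant to be driven. The "data/millionaid" root below is hypothetical and assumes the official train.zip/test.zip archives were already placed there manually; everything else uses only the API added in this diff.

    from torch.utils.data import DataLoader

    from torchgeo.datasets import MillionAID

    # Hypothetical root containing a manually downloaded train.zip (or its
    # already-extracted train/ directory); _verify() extracts the zip if needed.
    root = "data/millionaid"
    dataset = MillionAID(root=root, split="train", task="multi-class")
    print(len(dataset), "samples,", len(dataset.classes), "classes")

    # sample["image"] is a uint8 CxHxW tensor; for the multi-class task
    # sample["label"] is a one-element tensor of class indices.
    sample = dataset[0]
    fig = dataset.plot(sample, suptitle="Million-AID")

    # Image sizes vary between samples, so the default collate_fn can only
    # stack batches of size 1 unless a resize transform is passed via `transforms`.
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    batch = next(iter(loader))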