diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 50741377e8d..6c06768bef0 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -224,6 +224,11 @@ LoveDA
.. autoclass:: LoveDA
+Million-AID
+^^^^^^^^^^^
+
+.. autoclass:: MillionAID
+
NASA Marine Debris
^^^^^^^^^^^^^^^^^^
diff --git a/docs/api/non_geo_datasets.csv b/docs/api/non_geo_datasets.csv
index 2e37e18e21b..a7bdb09add2 100644
--- a/docs/api/non_geo_datasets.csv
+++ b/docs/api/non_geo_datasets.csv
@@ -16,6 +16,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
`LandCover.ai`_,S,Aerial,"10,674",5,512x512,0.25--0.5,RGB
`LEVIR-CD+`_,CD,Google Earth,985,2,"1,024x1,024",0.5,RGB
`LoveDA`_,S,Google Earth,"5,987",7,"1,024x1,024",0.3,RGB
+`Million-AID`_,C,Google Earth,1M,51--73,,0.5--153,RGB
`NASA Marine Debris`_,OD,PlanetScope,707,1,256x256,3,RGB
`OSCD`_,CD,Sentinel-2,24,2,"40--1,180",60,MSI
`PatternNet`_,C,Google Earth,"30,400",38,256x256,0.06--5,RGB
diff --git a/tests/data/millionaid/data.py b/tests/data/millionaid/data.py
new file mode 100644
index 00000000000..03ea05ad5df
--- /dev/null
+++ b/tests/data/millionaid/data.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import hashlib
+import os
+import shutil
+
+import numpy as np
+from PIL import Image
+
+SIZE = 32
+
+np.random.seed(0)
+
+PATHS = {
+ "train": [
+ os.path.join(
+ "train", "agriculture_land", "grassland", "meadow", "P0115918.jpg"
+ ),
+ os.path.join("train", "water_area", "beach", "P0060208.jpg"),
+ ],
+ "test": [
+ os.path.join("test", "agriculture_land", "grassland", "meadow", "P0115918.jpg"),
+ os.path.join("test", "water_area", "beach", "P0060208.jpg"),
+ ],
+}
+
+
+def create_file(path: str) -> None:
+ Z = np.random.rand(SIZE, SIZE, 3) * 255
+ img = Image.fromarray(Z.astype("uint8")).convert("RGB")
+ img.save(path)
+
+
+if __name__ == "__main__":
+ for split, paths in PATHS.items():
+ # remove old data
+ if os.path.isdir(split):
+ shutil.rmtree(split)
+ for path in paths:
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ create_file(path)
+
+ # compress data
+ shutil.make_archive(split, "zip", ".", split)
+
+ # Compute checksums
+ with open(split + ".zip", "rb") as f:
+ md5 = hashlib.md5(f.read()).hexdigest()
+ print(f"{split}: {md5}")
diff --git a/tests/data/millionaid/test.zip b/tests/data/millionaid/test.zip
new file mode 100644
index 00000000000..15a0bbd3f46
Binary files /dev/null and b/tests/data/millionaid/test.zip differ
diff --git a/tests/data/millionaid/test/agriculture_land/grassland/meadow/P0115918.jpg b/tests/data/millionaid/test/agriculture_land/grassland/meadow/P0115918.jpg
new file mode 100644
index 00000000000..1d04bf8a523
Binary files /dev/null and b/tests/data/millionaid/test/agriculture_land/grassland/meadow/P0115918.jpg differ
diff --git a/tests/data/millionaid/test/water_area/beach/P0060208.jpg b/tests/data/millionaid/test/water_area/beach/P0060208.jpg
new file mode 100644
index 00000000000..226ff70e36b
Binary files /dev/null and b/tests/data/millionaid/test/water_area/beach/P0060208.jpg differ
diff --git a/tests/data/millionaid/train.zip b/tests/data/millionaid/train.zip
new file mode 100644
index 00000000000..c3fdae60c55
Binary files /dev/null and b/tests/data/millionaid/train.zip differ
diff --git a/tests/data/millionaid/train/agriculture_land/grassland/meadow/P0115918.jpg b/tests/data/millionaid/train/agriculture_land/grassland/meadow/P0115918.jpg
new file mode 100644
index 00000000000..afc980b49dd
Binary files /dev/null and b/tests/data/millionaid/train/agriculture_land/grassland/meadow/P0115918.jpg differ
diff --git a/tests/data/millionaid/train/water_area/beach/P0060208.jpg b/tests/data/millionaid/train/water_area/beach/P0060208.jpg
new file mode 100644
index 00000000000..8d742f2c51b
Binary files /dev/null and b/tests/data/millionaid/train/water_area/beach/P0060208.jpg differ
diff --git a/tests/datasets/test_millionaid.py b/tests/datasets/test_millionaid.py
new file mode 100644
index 00000000000..751567e28a8
--- /dev/null
+++ b/tests/datasets/test_millionaid.py
@@ -0,0 +1,64 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+import shutil
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import pytest
+import torch
+import torch.nn as nn
+from _pytest.fixtures import SubRequest
+
+from torchgeo.datasets import MillionAID
+
+
+class TestMillionAID:
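+    # ``zip`` pairs splits with tasks, so the fixture covers only the
+    # (train, multi-class) and (test, multi-label) combinations rather than
+    # the full cross product.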
+ @pytest.fixture(
+ scope="class", params=zip(["train", "test"], ["multi-class", "multi-label"])
+ )
+ def dataset(self, request: SubRequest) -> MillionAID:
+ root = os.path.join("tests", "data", "millionaid")
+ split, task = request.param
+ transforms = nn.Identity()
+ return MillionAID(
+ root=root, split=split, task=task, transforms=transforms, checksum=True
+ )
+
+ def test_getitem(self, dataset: MillionAID) -> None:
+ x = dataset[0]
+ assert isinstance(x, dict)
+ assert isinstance(x["image"], torch.Tensor)
+ assert isinstance(x["label"], torch.Tensor)
+ assert x["image"].shape[0] == 3
+ assert x["image"].ndim == 3
+
+ def test_len(self, dataset: MillionAID) -> None:
+ assert len(dataset) == 2
+
+ def test_not_found(self, tmp_path: Path) -> None:
+ with pytest.raises(RuntimeError, match="Dataset not found in"):
+ MillionAID(str(tmp_path))
+
+ def test_not_extracted(self, tmp_path: Path) -> None:
+ url = os.path.join("tests", "data", "millionaid", "train.zip")
+ shutil.copy(url, tmp_path)
+ MillionAID(str(tmp_path))
+
+ def test_corrupted(self, tmp_path: Path) -> None:
+ with open(os.path.join(tmp_path, "train.zip"), "w") as f:
+ f.write("bad")
+ with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
+ MillionAID(str(tmp_path), checksum=True)
+
+ def test_plot(self, dataset: MillionAID) -> None:
+ x = dataset[0].copy()
+ dataset.plot(x, suptitle="Test")
+ plt.close()
+
+ def test_plot_prediction(self, dataset: MillionAID) -> None:
+ x = dataset[0].copy()
+ x["prediction"] = x["label"].clone()
+ dataset.plot(x)
+ plt.close()
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index 13196988787..f3b7a314626 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -68,6 +68,7 @@
)
from .levircd import LEVIRCDPlus
from .loveda import LoveDA
+from .millionaid import MillionAID
from .naip import NAIP
from .nasa_marine_debris import NASAMarineDebris
from .nwpu import VHR10
@@ -163,6 +164,7 @@
"LandCoverAI",
"LEVIRCDPlus",
"LoveDA",
+ "MillionAID",
"NASAMarineDebris",
"OSCD",
"PatternNet",
diff --git a/torchgeo/datasets/millionaid.py b/torchgeo/datasets/millionaid.py
new file mode 100644
index 00000000000..5136907ff4d
--- /dev/null
+++ b/torchgeo/datasets/millionaid.py
@@ -0,0 +1,371 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Million-AID dataset."""
+import glob
+import os
+from typing import Any, Callable, Dict, List, Optional, cast
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from PIL import Image
+from torch import Tensor
+
+from .geo import VisionDataset
+from .utils import check_integrity, extract_archive
+
+
+class MillionAID(VisionDataset):
+ """Million-AID Dataset.
+
+    The MillionAID dataset consists of one million aerial images from
+    Google Earth and offers either a multi-class learning task with
+    51 classes or a multi-label learning task with 73 different
+    possible labels. For more details please consult the accompanying
+    `paper <https://ieeexplore.ieee.org/document/9393553>`_.
+
+ Dataset features:
+
+ * RGB aerial images with varying resolutions from 0.5 m to 153 m per pixel
+    * images within classes can have different pixel dimensions
+
+ Dataset format:
+
+    * images are three-channel JPEGs
+
+ If you use this dataset in your research, please cite the following paper:
+
+ * https://ieeexplore.ieee.org/document/9393553
+
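+    Example usage (the ``root`` path below is a placeholder; this class does
+    not download the data, so the archives or extracted folders must already
+    be present there):
+
+    .. code-block:: python
+
+        from torchgeo.datasets import MillionAID
+
+        ds = MillionAID(root="data/millionaid", split="train", task="multi-class")
+        sample = ds[0]
+        image, label = sample["image"], sample["label"]
+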
+ .. versionadded:: 0.3
+ """
+
+ multi_label_categories = [
+ "agriculture_land",
+ "airport_area",
+ "apartment",
+ "apron",
+ "arable_land",
+ "bare_land",
+ "baseball_field",
+ "basketball_court",
+ "beach",
+ "bridge",
+ "cemetery",
+ "church",
+ "commercial_area",
+ "commercial_land",
+ "dam",
+ "desert",
+ "detached_house",
+ "dry_field",
+ "factory_area",
+ "forest",
+ "golf_course",
+ "grassland",
+ "greenhouse",
+ "ground_track_field",
+ "helipad",
+ "highway_area",
+ "ice_land",
+ "industrial_land",
+ "intersection",
+ "island",
+ "lake",
+ "leisure_land",
+ "meadow",
+ "mine",
+ "mining_area",
+ "mobile_home_park",
+ "oil_field",
+ "orchard",
+ "paddy_field",
+ "parking_lot",
+ "pier",
+ "port_area",
+ "power_station",
+ "public_service_land",
+ "quarry",
+ "railway",
+ "railway_area",
+ "religious_land",
+ "residential_land",
+ "river",
+ "road",
+ "rock_land",
+ "roundabout",
+ "runway",
+ "solar_power_plant",
+ "sparse_shrub_land",
+ "special_land",
+ "sports_land",
+ "stadium",
+ "storage_tank",
+ "substation",
+ "swimming_pool",
+ "tennis_court",
+ "terraced_field",
+ "train_station",
+ "transportation_land",
+ "unutilized_land",
+ "viaduct",
+ "wastewater_plant",
+ "water_area",
+ "wind_turbine",
+ "woodland",
+ "works",
+ ]
+
+ multi_class_categories = [
+ "apartment",
+ "apron",
+ "bare_land",
+ "baseball_field",
+ "bapsketball_court",
+ "beach",
+ "bridge",
+ "cemetery",
+ "church",
+ "commercial_area",
+ "dam",
+ "desert",
+ "detached_house",
+ "dry_field",
+ "forest",
+ "golf_course",
+ "greenhouse",
+ "ground_track_field",
+ "helipad",
+ "ice_land",
+ "intersection",
+ "island",
+ "lake",
+ "meadow",
+ "mine",
+ "mobile_home_park",
+ "oil_field",
+ "orchard",
+ "paddy_field",
+ "parking_lot",
+ "pier",
+ "quarry",
+ "railway",
+ "river",
+ "road",
+ "rock_land",
+ "roundabout",
+ "runway",
+ "solar_power_plant",
+ "sparse_shrub_land",
+ "stadium",
+ "storage_tank",
+ "substation",
+ "swimming_pool",
+ "tennis_court",
+ "terraced_field",
+ "train_station",
+ "viaduct",
+ "wastewater_plant",
+ "wind_turbine",
+ "works",
+ ]
+
+ md5s = {
+ "train": "1b40503cafa9b0601653ca36cd788852",
+ "test": "51a63ee3eeb1351889eacff349a983d8",
+ }
+
+ filenames = {"train": "train.zip", "test": "test.zip"}
+
+ tasks = ["multi-class", "multi-label"]
+ splits = ["train", "test"]
+
+ def __init__(
+ self,
+ root: str = "data",
+ task: str = "multi-class",
+ split: str = "train",
+ transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
+ checksum: bool = False,
+ ) -> None:
+ """Initialize a new MillionAID dataset instance.
+
+ Args:
+ root: root directory where dataset can be found
+ task: type of task, either "multi-class" or "multi-label"
+ split: train or test split
+ transforms: a function/transform that takes input sample and its target as
+ entry and returns a transformed version
+ checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+        Raises:
+            AssertionError: if ``task`` or ``split`` is invalid
+            RuntimeError: if dataset is not found
+ """
+ self.root = root
+ self.transforms = transforms
+ self.checksum = checksum
+ assert task in self.tasks
+ assert split in self.splits
+ self.task = task
+ self.split = split
+
+ self._verify()
+
+ self.files = self._load_files(self.root)
+
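+        # Class indices come from the sorted set of labels that actually occur
+        # under ``root`` for the chosen split and task, not from the full
+        # category lists defined above.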
+ self.classes = sorted({cls for f in self.files for cls in f["label"]})
+ self.class_to_idx: Dict[str, int] = {c: i for i, c in enumerate(self.classes)}
+
+ def __len__(self) -> int:
+ """Return the number of data points in the dataset.
+
+ Returns:
+ length of the dataset
+ """
+ return len(self.files)
+
+ def __getitem__(self, index: int) -> Dict[str, Tensor]:
+ """Return an index within the dataset.
+
+ Args:
+ index: index to return
+
+ Returns:
+ data and label at that index
+ """
+ files = self.files[index]
+ image = self._load_image(files["image"])
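+        # files["label"] holds a single class for the multi-class task and two
+        # or three entries (scene, optional subcategory, class) for the
+        # multi-label task, so the label tensor has variable length.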
+ cls_label = [self.class_to_idx[label] for label in files["label"]]
+ label = torch.tensor(cls_label, dtype=torch.long)
+ sample = {"image": image, "label": label}
+
+ if self.transforms is not None:
+ sample = self.transforms(sample)
+
+ return sample
+
+ def _load_files(self, root: str) -> List[Dict[str, Any]]:
+ """Return the paths of the files in the dataset.
+
+ Args:
+ root: root directory of dataset
+
+ Returns:
+            list of dicts containing the image path and the list of labels
+            for each sample
+ """
+        imgs_no_subcat = glob.glob(os.path.join(root, self.split, "*", "*", "*.jpg"))
+
+        imgs_subcat = glob.glob(
+            os.path.join(root, self.split, "*", "*", "*", "*.jpg")
+        )
+
+ scenes = [p.split(os.sep)[-3] for p in imgs_no_subcat] + [
+ p.split(os.sep)[-4] for p in imgs_subcat
+ ]
+
+ subcategories = ["Missing" for p in imgs_no_subcat] + [
+ p.split(os.sep)[-3] for p in imgs_subcat
+ ]
+
+ classes = [p.split(os.sep)[-2] for p in imgs_no_subcat] + [
+ p.split(os.sep)[-2] for p in imgs_subcat
+ ]
+
+ if self.task == "multi-label":
+ labels = [
+ [sc, sub, c] if sub != "Missing" else [sc, c]
+ for sc, sub, c in zip(scenes, subcategories, classes)
+ ]
+ else:
+ labels = [[c] for c in classes]
+
+ images = imgs_no_subcat + imgs_subcat
+
+        files = [dict(image=img, label=label) for img, label in zip(images, labels)]
+
+ return files
+
+ def _load_image(self, path: str) -> Tensor:
+ """Load a single image.
+
+ Args:
+ path: path to the image
+
+ Returns:
+ the image
+ """
+ with Image.open(path) as img:
+ array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB"))
+ tensor: Tensor = torch.from_numpy(array)
+ # Convert from HxWxC to CxHxW
+ tensor = tensor.permute((2, 0, 1))
+ return tensor
+
+ def _verify(self) -> None:
+ """Checks the integrity of the dataset structure.
+
+ Returns:
+ True if the dataset directories are found, else False
+ """
+ filepath = os.path.join(self.root, self.split)
+ if os.path.isdir(filepath):
+ return
+
+ filepath = os.path.join(self.root, self.split + ".zip")
+ if os.path.isfile(filepath):
+ if self.checksum and not check_integrity(filepath, self.md5s[self.split]):
+ raise RuntimeError("Dataset found, but corrupted.")
+ extract_archive(filepath)
+ return
+
+ raise RuntimeError(
+ f"Dataset not found in `root={self.root}` directory, either "
+ "specify a different `root` directory or manually download "
+ "the dataset to this directory."
+ )
+
+ def plot(
+ self,
+ sample: Dict[str, Tensor],
+ show_titles: bool = True,
+ suptitle: Optional[str] = None,
+ ) -> plt.Figure:
+ """Plot a sample from the dataset.
+
+ Args:
+ sample: a sample returned by :meth:`__getitem__`
+ show_titles: flag indicating whether to show titles above each panel
+ suptitle: optional string to use as a suptitle
+
+ Returns:
+ a matplotlib Figure with the rendered sample
+
+ """
+ image = np.rollaxis(sample["image"].numpy(), 0, 3)
+ labels = [self.classes[cast(int, label)] for label in sample["label"]]
+
+ showing_predictions = "prediction" in sample
+ if showing_predictions:
+ prediction_labels = [
+ self.classes[cast(int, label)] for label in sample["prediction"]
+ ]
+
+ fig, ax = plt.subplots(figsize=(4, 4))
+ ax.imshow(image)
+ ax.axis("off")
+ if show_titles:
+ title = f"Label: {labels}"
+ if showing_predictions:
+ title += f"\nPrediction: {prediction_labels}"
+ ax.set_title(title)
+
+ if suptitle is not None:
+ plt.suptitle(suptitle)
+ return fig