diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 34835176851..645acbacc47 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -166,6 +166,13 @@ Sentinel .. autoclass:: Sentinel1 .. autoclass:: Sentinel2 + +South America Soybean +^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: SouthAmericaSoybean + + .. _Non-geospatial Datasets: Non-geospatial Datasets diff --git a/docs/api/geo_datasets.csv b/docs/api/geo_datasets.csv index d8596a16068..ed7655e843d 100644 --- a/docs/api/geo_datasets.csv +++ b/docs/api/geo_datasets.csv @@ -24,3 +24,4 @@ Dataset,Type,Source,License,Size (px),Resolution (m) `Open Buildings`_,Geometries,"Maxar, CNES/Airbus","CC-BY-4.0 OR ODbL-1.0",-,- `PRISMA`_,Imagery,PRISMA,-,512x512,5--30 `Sentinel`_,Imagery,Sentinel,"CC-BY-SA-3.0-IGO","10,000x10,000",10 +`South America Soybean`_,Masks,"Landsat, MODIS",-,-,30 diff --git a/tests/data/south_america_soybean/SouthAmericaSoybean.zip b/tests/data/south_america_soybean/SouthAmericaSoybean.zip new file mode 100644 index 00000000000..5453b89fc25 Binary files /dev/null and b/tests/data/south_america_soybean/SouthAmericaSoybean.zip differ diff --git a/tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2002.tif b/tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2002.tif new file mode 100644 index 00000000000..95667ce067c Binary files /dev/null and b/tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2002.tif differ diff --git a/tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2021.tif b/tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2021.tif new file mode 100644 index 00000000000..a220b500677 Binary files /dev/null and b/tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2021.tif differ diff --git a/tests/data/south_america_soybean/data.py b/tests/data/south_america_soybean/data.py new file mode 100644 index 00000000000..fbe7d7b23d1 --- /dev/null +++ b/tests/data/south_america_soybean/data.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +import hashlib +import os +import shutil + +import numpy as np +import rasterio +from rasterio.crs import CRS +from rasterio.transform import Affine + +SIZE = 32 + + +np.random.seed(0) +files = ["South_America_Soybean_2002.tif", "South_America_Soybean_2021.tif"] + + +def create_file(path: str, dtype: str): + """Create the testing file.""" + profile = { + "driver": "GTiff", + "dtype": dtype, + "count": 1, + "crs": CRS.from_epsg(4326), + "transform": Affine( + 0.0002499999999999943131, + 0.0, + -82.0005000000000024, + 0.0, + -0.0002499999999999943131, + 0.0005000000000000, + ), + "height": SIZE, + "width": SIZE, + "compress": "lzw", + "predictor": 2, + } + + allowed_values = [0, 1] + + Z = np.random.choice(allowed_values, size=(SIZE, SIZE)) + + with rasterio.open(path, "w", **profile) as src: + src.write(Z, 1) + + +if __name__ == "__main__": + dir = os.path.join(os.getcwd(), "SouthAmericaSoybean") + if os.path.exists(dir) and os.path.isdir(dir): + shutil.rmtree(dir) + + os.makedirs(dir, exist_ok=True) + + for file in files: + create_file(os.path.join(dir, file), dtype="int8") + + # Compress data + shutil.make_archive("SouthAmericaSoybean", "zip", ".", dir) + + # Compute checksums + with open("SouthAmericaSoybean.zip", "rb") as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(f"SouthAmericaSoybean.zip: {md5}") diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py new file mode 100644 index 00000000000..11dcc2b5ff9 --- /dev/null +++ b/tests/datasets/test_south_america_soybean.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from pytest import MonkeyPatch +from rasterio.crs import CRS + +import torchgeo.datasets.utils +from torchgeo.datasets import ( + BoundingBox, + DatasetNotFoundError, + IntersectionDataset, + SouthAmericaSoybean, + UnionDataset, +) + + +def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: + shutil.copy(url, root) + + +class TestSouthAmericaSoybean: + @pytest.fixture + def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> SouthAmericaSoybean: + monkeypatch.setattr( + torchgeo.datasets.south_america_soybean, "download_url", download_url + ) + transforms = nn.Identity() + url = os.path.join( + "tests", + "data", + "south_america_soybean", + "SouthAmericaSoybean", + "South_America_Soybean_{}.tif", + ) + + monkeypatch.setattr(SouthAmericaSoybean, "url", url) + root = str(tmp_path) + return SouthAmericaSoybean( + paths=root, + years=[2002, 2021], + transforms=transforms, + download=True, + checksum=True, + ) + + def test_getitem(self, dataset: SouthAmericaSoybean) -> None: + x = dataset[dataset.bounds] + assert isinstance(x, dict) + assert isinstance(x["crs"], CRS) + assert isinstance(x["mask"], torch.Tensor) + + def test_and(self, dataset: SouthAmericaSoybean) -> None: + ds = dataset & dataset + assert isinstance(ds, IntersectionDataset) + + def test_or(self, dataset: SouthAmericaSoybean) -> None: + ds = dataset | dataset + assert isinstance(ds, UnionDataset) + + def test_already_extracted(self, dataset: SouthAmericaSoybean) -> None: + SouthAmericaSoybean(dataset.paths, download=True) + + def test_already_downloaded(self, tmp_path: Path) -> None: + pathname = os.path.join( + "tests", + "data", + "south_america_soybean", + "SouthAmericaSoybean", + "South_America_Soybean_2002.tif", + ) + root = str(tmp_path) + shutil.copy(pathname, root) + SouthAmericaSoybean(root) + + def test_plot(self, dataset: SouthAmericaSoybean) -> None: + query = dataset.bounds + x = dataset[query] + dataset.plot(x, suptitle="Test") + plt.close() + + def test_plot_prediction(self, dataset: SouthAmericaSoybean) -> None: + query = dataset.bounds + x = dataset[query] + x["prediction"] = x["mask"].clone() + dataset.plot(x, suptitle="Prediction") + plt.close() + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + SouthAmericaSoybean(str(tmp_path)) + + def test_invalid_query(self, dataset: SouthAmericaSoybean) -> None: + query = BoundingBox(0, 0, 0, 0, 0, 0) + with pytest.raises( + IndexError, match="query: .* not found in index with bounds:" + ): + dataset[query] diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index fd20414e995..235ab83c8eb 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -98,6 +98,7 @@ from .sentinel import Sentinel, Sentinel1, Sentinel2 from .skippd import SKIPPD from .so2sat import So2Sat +from .south_america_soybean import SouthAmericaSoybean from .spacenet import ( SpaceNet, SpaceNet1, @@ -185,6 +186,7 @@ "Sentinel", "Sentinel1", "Sentinel2", + "SouthAmericaSoybean", # NonGeoDataset "ADVANCE", "BeninSmallHolderCashews", diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py new file mode 100644 index 00000000000..e1d1e952cf0 --- /dev/null +++ b/torchgeo/datasets/south_america_soybean.py @@ -0,0 +1,175 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""South America Soybean Dataset.""" + +from collections.abc import Iterable +from typing import Any, Callable, Optional, Union + +import matplotlib.pyplot as plt +from matplotlib.figure import Figure +from rasterio.crs import CRS + +from .geo import RasterDataset +from .utils import DatasetNotFoundError, download_url + + +class SouthAmericaSoybean(RasterDataset): + """South America Soybean Dataset. + + This dataset produced annual 30-m soybean maps of South America from 2001 to 2021. + + Link: https://www.nature.com/articles/s41893-021-00729-z + + Dataset contains 2 classes: + + 0. other + 1. soybean + + Dataset Format: + + * 21 .tif files + + + If you use this dataset in your research, please cite the following paper: + + * https://doi.org/10.1038/s41893-021-00729-z + + .. versionadded:: 0.6 + """ + + filename_glob = "South_America_Soybean_*.*" + filename_regex = r"South_America_Soybean_(?P\d{4})" + + date_format = "%Y" + is_image = False + url = "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_{}.tif" + + md5s = { + 2021: "edff3ada13a1a9910d1fe844d28ae4f", + 2020: "0709dec807f576c9707c8c7e183db31", + 2019: "441836493bbcd5e123cff579a58f5a4f", + 2018: "503c2d0a803c2a2629ebbbd9558a3013", + 2017: "4d0487ac1105d171e5f506f1766ea777", + 2016: "770c558f6ac40550d0e264da5e44b3e", + 2015: "6beb96a61fe0e9ce8c06263e500dde8f", + 2014: "824ff91c62a4ba9f4ccfd281729830e5", + 2013: "0263e19b3cae6fdaba4e3b450cef985e", + 2012: "9f3a71097c9836fcff18a13b9ba608b2", + 2011: "b73352ebea3d5658959e9044ec526143", + 2010: "9264532d36ffa93493735a6e44caef0d", + 2009: "341387c1bb42a15140c80702e4cca02d", + 2008: "96fc3f737ab3ce9bcd16cbf7761427e2", + 2007: "bb8549b6674163fe20ffd47ec4ce8903", + 2006: "eabaa525414ecbff89301d3d5c706f0b", + 2005: "89faae27f9b5afbd06935a465e5fe414", + 2004: "f9882ca9c70e054e50172835cb75a8c3", + 2003: "cad5ed461ff4ab45c90177841aaecad2", + 2002: "8a4a9dcea54b3ec7de07657b9f2c0893", + 2001: "2914b0af7590a0ca4dfa9ccefc99020f", + } + + def __init__( + self, + paths: Union[str, Iterable[str]] = "data", + crs: Optional[CRS] = None, + res: Optional[float] = None, + years: list[int] = [2021], + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + cache: bool = True, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new Dataset instance. + + Args: + paths: one or more root directories to search or files to load + crs: :term:`coordinate reference system (CRS)` to warp to + (defaults to the CRS of the first file found) + res: resolution of the dataset in units of CRS + (defaults to the resolution of the first file found) + years: list of years for which to use the South America Soybean layer + transforms: a function/transform that takes an input sample + and returns a transformed version + cache: if True, cache file handle to speed up repeated sampling + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 after downloading files (may be slow) + + Raises: + DatasetNotFoundError: If dataset is not found and *download* is False. + """ + self.paths = paths + self.download = download + self.checksum = checksum + self.years = years + self._verify() + + super().__init__(paths, crs, res, transforms=transforms, cache=cache) + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the extracted files already exist + if self.files: + return + assert isinstance(self.paths, str) + + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) + + # Download the dataset + self._download() + + def _download(self) -> None: + """Download the dataset.""" + for year in self.years: + download_url( + self.url.format(year), + self.paths, + md5=self.md5s[year] if self.checksum else None, + ) + + def plot( + self, + sample: dict[str, Any], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`RasterDataset.__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + """ + mask = sample["mask"].squeeze() + ncols = 1 + + showing_predictions = "prediction" in sample + if showing_predictions: + pred = sample["prediction"].squeeze() + ncols = 2 + + fig, axs = plt.subplots( + nrows=1, ncols=ncols, figsize=(ncols * 4, 4), squeeze=False + ) + + axs[0, 0].imshow(mask, interpolation="none") + axs[0, 0].axis("off") + + if show_titles: + axs[0, 0].set_title("Mask") + + if showing_predictions: + axs[0, 1].imshow(pred, interpolation="none") + axs[0, 1].axis("off") + if show_titles: + axs[0, 1].set_title("Prediction") + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig