From e510a069a75dfd337b11d2c8f8b0734fb42a77d4 Mon Sep 17 00:00:00 2001
From: cookie-kyu
Date: Sun, 15 Oct 2023 23:54:09 -0500
Subject: [PATCH 01/72] Created file for South America Soybean dataset and
 added it to __init__.py

---
 torchgeo/datasets/__init__.py              | 2 ++
 torchgeo/datasets/south_america_soybean.py | 1 +
 2 files changed, 3 insertions(+)
 create mode 100644 torchgeo/datasets/south_america_soybean.py

diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index 8fd3f2ee206..abe83cdb374 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -92,6 +92,7 @@
 from .sentinel import Sentinel, Sentinel1, Sentinel2
 from .skippd import SKIPPD
 from .so2sat import So2Sat
+from .south_america_soybean import SouthAmericaSoybean
 from .spacenet import (
     SpaceNet,
     SpaceNet1,
@@ -172,6 +173,7 @@
     "Sentinel",
     "Sentinel1",
     "Sentinel2",
+    "SouthAmericaSoybean",
     # NonGeoDataset
     "ADVANCE",
     "BeninSmallHolderCashews",
diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py
new file mode 100644
index 00000000000..38eb0249b27
--- /dev/null
+++ b/torchgeo/datasets/south_america_soybean.py
@@ -0,0 +1 @@
+"""South America Soybean Dataset."""
\ No newline at end of file

From c67cae41a2ee85283c93a4a566e38adc40cd1009 Mon Sep 17 00:00:00 2001
From: cookie-kyu
Date: Wed, 25 Oct 2023 01:00:15 -0500
Subject: [PATCH 02/72] Updated south_america_soybean.py

---
 torchgeo/datasets/south_america_soybean.py | 228 ++++++++++++++++++++-
 1 file changed, 227 insertions(+), 1 deletion(-)

diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py
index 38eb0249b27..a4d1edb1ddd 100644
--- a/torchgeo/datasets/south_america_soybean.py
+++ b/torchgeo/datasets/south_america_soybean.py
@@ -1 +1,227 @@
-"""South America Soybean Dataset."""
\ No newline at end of file
+import glob
import os
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union

import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from rasterio.crs import CRS

from .geo import RasterDataset
from .utils import BoundingBox, download_url


class SouthAmericaSoybean(RasterDataset):
    """South America Soybean Dataset.

    Link: https://www.nature.com/articles/s41893-021-00729-z

    The dataset contains 2 classes:

    0. other
    1. soybean

    Dataset format:

    * 21 .tif files, one per year from 2001 to 2021

    If you use this dataset in your research, please use the corresponding citation:

    * Song, X.-P., Hansen, M. C., Potapov, P. et al. Massive soybean expansion in
      South America since 2000 and implications for conservation. Nat Sustain 4,
      784–792 (2021). https://doi.org/10.1038/s41893-021-00729-z
    """

    filename_glob = "SouthAmerica_Soybean_*.tif"
    filename_regex = r"SouthAmerica_Soybean_(?P<date>\d{4})\.tif"

    date_format = "%Y"
    is_image = False

    url = "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_"

    md5s = {
        2001: "2914b0af7590a0ca4dfa9ccefc99020f",
        2002: "8a4a9dcea54b3ec7de07657b9f2c0893",
        2003: "cad5ed461ff4ab45c90177841aaecad2",
        2004: "f9882ca9c70e054e50172835cb75a8c3",
        2005: "89faae27f9b5afbd06935a465e5fe414",
        2006: "eabaa525414ecbff89301d3d5c706f0b",
        2007: "bb8549b6674163fe20ffd47ec4ce8903",
        2008: "96fc3f737ab3ce9bcd16cbf7761427e2",
        2009: "341387c1bb42a15140c80702e4cca02d",
        2010: "9264532d36ffa93493735a6e44caef0d",
        2011: "b73352ebea3d5658959e9044ec526143",
        2012: "9f3a71097c9836fcff18a13b9ba608b2",
        2013: "0263e19b3cae6fdaba4e3b450cef985e",
        2014: "824ff91c62a4ba9f4ccfd281729830e5",
        2015: "6beb96a61fe0e9ce8c06263e500dde8f",
        2016: "770c558f6ac40550d0e264da5e44b3e",
        2017: "4d0487ac1105d171e5f506f1766ea777",
        2018: "503c2d0a803c2a2629ebbbd9558a3013",
        2019: "441836493bbcd5e123cff579a58f5a4f",
        2020: "0709dec807f576c9707c8c7e183db31",
        2021: "edff3ada13a1a9910d1fe844d28ae4f",
    }

    cmap = {0: (0, 0, 0, 0), 1: (255, 0, 255, 255)}

    def __init__(
        self,
        paths: Union[str, Iterable[str]] = "data",
        crs: Optional[CRS] = None,
        res: Optional[float] = None,
        years: list[int] = [2021],
        classes: list[int] = list(cmap.keys()),
        transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
        cache: bool = True,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new SouthAmericaSoybean instance.

        Args:
            paths: one or more root directories to search or files to load
            crs: :term:`coordinate reference system (CRS)` to warp to
                (defaults to the CRS of the first file found)
            res: resolution of the dataset in units of CRS
                (defaults to the resolution of the first file found)
            years: list of years for which to use the soybean mask
            classes: list of classes to include; all other classes are mapped to 0
                (defaults to all classes)
            transforms: a function/transform that takes an input sample
                and returns a transformed version
            cache: if True, cache file handle to speed up repeated sampling
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 after downloading files (may be slow)

        Raises:
            AssertionError: if ``years`` or ``classes`` are invalid
            FileNotFoundError: if no files are found in ``paths``
            RuntimeError: if ``download=False`` but dataset is missing or checksum fails
        """
        assert set(years) <= self.md5s.keys(), (
            "South America Soybean data only exists for the following years: "
            f"{list(self.md5s.keys())}."
        )
        assert (
            set(classes) <= self.cmap.keys()
        ), f"Only the following classes are valid: {list(self.cmap.keys())}."
        assert 0 in classes, "Classes must include the background class: 0"

        self.years = years
        self.paths = paths
        self.classes = classes
        self.download = download
        self.checksum = checksum
        self.ordinal_map = torch.zeros(max(self.cmap.keys()) + 1, dtype=self.dtype)
        self.ordinal_cmap = torch.zeros((len(self.classes), 4), dtype=torch.uint8)

        self._verify()

        super().__init__(paths, crs, res, transforms=transforms, cache=cache)

        # Map the requested classes to ordinal indices and build a matching color map
        for v, k in enumerate(self.classes):
            self.ordinal_map[k] = v
            self.ordinal_cmap[v] = torch.tensor(self.cmap[k])

    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
        """Retrieve mask and metadata indexed by query.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            sample of mask and metadata at that index

        Raises:
            IndexError: if query is not found in the index
        """
        sample = super().__getitem__(query)
        sample["mask"] = self.ordinal_map[sample["mask"]]
        return sample

    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Raises:
            RuntimeError: if ``download=False`` but dataset is missing or checksum fails
        """
        # Check if the files already exist (the dataset ships as plain GeoTIFFs,
        # so there is no archive to extract)
        if self.files:
            return

        # Check if the user requested to download the dataset
        if not self.download:
            raise RuntimeError(
                f"Dataset not found in `root={self.paths}` and `download=False`, "
                "either specify a different `root` directory or use `download=True` "
                "to automatically download the dataset."
            )

        # Download the dataset
        self._download()

    def _download(self) -> None:
        """Download the dataset."""
        assert isinstance(self.paths, str)
        for year in self.years:
            download_url(
                self.url + str(year) + ".tif",
                self.paths,
                md5=self.md5s[year] if self.checksum else None,
            )

    def plot(
        self,
        sample: dict[str, Any],
        show_titles: bool = True,
        suptitle: Optional[str] = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`RasterDataset.__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        mask = sample["mask"].squeeze()
        ncols = 1

        showing_predictions = "prediction" in sample
        if showing_predictions:
            pred = sample["prediction"].squeeze()
            ncols = 2

        fig, axs = plt.subplots(
            nrows=1, ncols=ncols, figsize=(ncols * 4, 4), squeeze=False
        )

        axs[0, 0].imshow(self.ordinal_cmap[mask], interpolation="none")
        axs[0, 0].axis("off")

        if show_titles:
            axs[0, 0].set_title("Mask")

        if showing_predictions:
            axs[0, 1].imshow(self.ordinal_cmap[pred], interpolation="none")
            axs[0, 1].axis("off")
            if show_titles:
                axs[0, 1].set_title("Prediction")

        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig

From 79e697090511d145ae058a4de4b7180acd104584 Mon Sep 17 00:00:00 2001
From: cookie-kyu
Date: Wed, 8 Nov 2023 11:17:24 -0600
Subject: [PATCH 03/72] Added tests

---
 .github/workflows/release.yaml | 6 +-
 .github/workflows/style.yaml | 10 +-
 .github/workflows/tests.yaml | 8 +-
 .github/workflows/tutorials.yaml | 2 +-
 .pre-commit-config.yaml | 10 +-
 README.md | 15 +++
 docs/api/datasets.rst | 7 +-
 docs/api/geo_datasets.csv | 1 +
 docs/api/non_geo_datasets.csv | 2 +-
 experiments/ssl4eo/flops.py | 2 +-
 experiments/ssl4eo/landsat/README.md | 2 +-
 .../ssl4eo/landsat/plot_landsat_bands.py | 2 +-
 .../ssl4eo/landsat/plot_landsat_timeline.py | 2 +-
 experiments/torchgeo/plot_bar_chart.py | 2 +-
 .../torchgeo/plot_dataloader_benchmark.py | 2 +-
 .../torchgeo/plot_percentage_benchmark.py | 2 +-
 pyproject.toml | 3 +
 requirements/required.txt | 12 +-
 requirements/style.txt | 2 +-
 requirements/tests.txt | 6 +-
 tests/.DS_Store | Bin 0 -> 6148 bytes
 ...eground_Live_Woody_Biomass_Density.geojson | 2 +-
 tests/data/agb_live_woody_density/data.py | 4 +-
tests/data/south_america_soybean/.DS_Store | Bin 0 -> 6148 bytes tests/data/south_america_soybean/data.py | 65 +++++++++++ tests/datamodules/test_levircd.py | 78 +++++++++++++ tests/datamodules/test_oscd.py | 45 ++++--- tests/datasets/test_advance.py | 4 +- tests/datasets/test_agb_live_woody_density.py | 3 +- tests/datasets/test_astergdem.py | 10 +- tests/datasets/test_benin_cashews.py | 4 +- tests/datasets/test_bigearthnet.py | 7 +- tests/datasets/test_biomassters.py | 5 +- tests/datasets/test_cbf.py | 3 +- tests/datasets/test_cdl.py | 10 +- tests/datasets/test_chesapeake.py | 5 +- tests/datasets/test_cloud_cover.py | 4 +- tests/datasets/test_cms_mangrove_canopy.py | 9 +- tests/datasets/test_cowc.py | 7 +- tests/datasets/test_cv4a_kenya_crop_type.py | 4 +- tests/datasets/test_cyclone.py | 4 +- tests/datasets/test_deepglobelandcover.py | 9 +- tests/datasets/test_dfc2022.py | 4 +- tests/datasets/test_eddmaps.py | 10 +- tests/datasets/test_enviroatlas.py | 3 +- tests/datasets/test_esri2020.py | 10 +- tests/datasets/test_etci2021.py | 4 +- tests/datasets/test_eudem.py | 10 +- tests/datasets/test_eurosat.py | 7 +- tests/datasets/test_fair1m.py | 4 +- tests/datasets/test_fire_risk.py | 4 +- tests/datasets/test_forestdamage.py | 4 +- tests/datasets/test_gbif.py | 10 +- tests/datasets/test_geo.py | 26 ++++- tests/datasets/test_gid15.py | 4 +- tests/datasets/test_globbiomass.py | 3 +- tests/datasets/test_idtrees.py | 8 +- tests/datasets/test_inaturalist.py | 3 +- tests/datasets/test_inria.py | 4 +- tests/datasets/test_l7irish.py | 10 +- tests/datasets/test_l8biome.py | 10 +- tests/datasets/test_landcoverai.py | 11 +- tests/datasets/test_landsat.py | 10 +- tests/datasets/test_levircd.py | 10 +- tests/datasets/test_loveda.py | 6 +- tests/datasets/test_mapinwild.py | 4 +- tests/datasets/test_millionaid.py | 4 +- tests/datasets/test_naip.py | 10 +- tests/datasets/test_nasa_marine_debris.py | 7 +- tests/datasets/test_nlcd.py | 10 +- tests/datasets/test_openbuildings.py | 22 +--- tests/datasets/test_oscd.py | 26 +++-- tests/datasets/test_pastis.py | 4 +- tests/datasets/test_patternnet.py | 7 +- tests/datasets/test_potsdam.py | 4 +- tests/datasets/test_reforestree.py | 4 +- tests/datasets/test_resisc45.py | 7 +- tests/datasets/test_rwanda_field_boundary.py | 4 +- tests/datasets/test_seasonet.py | 4 +- tests/datasets/test_seco.py | 4 +- tests/datasets/test_sen12ms.py | 6 +- tests/datasets/test_sentinel.py | 5 +- tests/datasets/test_skippd.py | 4 +- tests/datasets/test_so2sat.py | 4 +- tests/datasets/test_south_america_soybean.py | 110 ++++++++++++++++++ tests/datasets/test_spacenet.py | 13 ++- tests/datasets/test_ssl4eo.py | 6 +- tests/datasets/test_ssl4eo_benchmark.py | 10 +- .../datasets/test_sustainbench_crop_yield.py | 4 +- tests/datasets/test_ucmerced.py | 7 +- tests/datasets/test_usavars.py | 4 +- tests/datasets/test_utils.py | 54 ++++++++- tests/datasets/test_vaihingen.py | 4 +- tests/datasets/test_vhr10.py | 4 +- .../test_western_usa_live_fuel_moisture.py | 4 +- tests/datasets/test_xview2.py | 4 +- tests/datasets/test_zuericrop.py | 7 +- torchgeo/datamodules/__init__.py | 2 + torchgeo/datamodules/eurosat.py | 80 +++++++------ torchgeo/datamodules/levircd.py | 70 +++++++++++ torchgeo/datamodules/oscd.py | 82 ++++++------- torchgeo/datasets/__init__.py | 3 + torchgeo/datasets/advance.py | 16 +-- torchgeo/datasets/agb_live_woody_density.py | 23 +--- torchgeo/datasets/astergdem.py | 16 +-- torchgeo/datasets/benin_cashews.py | 14 ++- torchgeo/datasets/bigearthnet.py | 22 ++-- 
torchgeo/datasets/biomassters.py | 7 +- torchgeo/datasets/cbf.py | 11 +- torchgeo/datasets/cdl.py | 17 +-- torchgeo/datasets/chesapeake.py | 32 ++--- torchgeo/datasets/cloud_cover.py | 17 ++- torchgeo/datasets/cms_mangrove_canopy.py | 17 +-- torchgeo/datasets/cowc.py | 10 +- torchgeo/datasets/cv4a_kenya_crop_type.py | 17 ++- torchgeo/datasets/cyclone.py | 17 ++- torchgeo/datasets/deepglobelandcover.py | 16 +-- torchgeo/datasets/dfc2022.py | 20 ++-- torchgeo/datasets/eddmaps.py | 6 +- torchgeo/datasets/enviroatlas.py | 17 +-- torchgeo/datasets/esri2020.py | 17 +-- torchgeo/datasets/etci2021.py | 16 +-- torchgeo/datasets/eudem.py | 16 +-- torchgeo/datasets/eurosat.py | 47 +++----- torchgeo/datasets/fair1m.py | 23 +--- torchgeo/datasets/fire_risk.py | 16 +-- torchgeo/datasets/forestdamage.py | 22 ++-- torchgeo/datasets/gbif.py | 6 +- torchgeo/datasets/geo.py | 33 +++--- torchgeo/datasets/gid15.py | 16 +-- torchgeo/datasets/globbiomass.py | 17 +-- torchgeo/datasets/idtrees.py | 15 +-- torchgeo/datasets/inaturalist.py | 6 +- torchgeo/datasets/inria.py | 15 +-- torchgeo/datasets/l7irish.py | 17 +-- torchgeo/datasets/l8biome.py | 17 +-- torchgeo/datasets/landcoverai.py | 29 ++--- torchgeo/datasets/landsat.py | 2 +- torchgeo/datasets/levircd.py | 44 +++---- torchgeo/datasets/loveda.py | 19 +-- torchgeo/datasets/mapinwild.py | 11 +- torchgeo/datasets/millionaid.py | 10 +- torchgeo/datasets/nasa_marine_debris.py | 22 ++-- torchgeo/datasets/nlcd.py | 17 +-- torchgeo/datasets/openbuildings.py | 28 +---- torchgeo/datasets/oscd.py | 66 +++++++---- torchgeo/datasets/pastis.py | 17 +-- torchgeo/datasets/patternnet.py | 17 +-- torchgeo/datasets/potsdam.py | 17 ++- torchgeo/datasets/reforestree.py | 28 ++--- torchgeo/datasets/resisc45.py | 65 +---------- torchgeo/datasets/rwanda_field_boundary.py | 28 ++--- torchgeo/datasets/seasonet.py | 23 ++-- torchgeo/datasets/seco.py | 22 ++-- torchgeo/datasets/sen12ms.py | 14 +-- torchgeo/datasets/sentinel.py | 4 +- torchgeo/datasets/skippd.py | 22 +--- torchgeo/datasets/so2sat.py | 6 +- torchgeo/datasets/south_america_soybean.py | 2 +- torchgeo/datasets/spacenet.py | 35 ++---- torchgeo/datasets/ssl4eo.py | 26 ++--- torchgeo/datasets/ssl4eo_benchmark.py | 16 +-- torchgeo/datasets/sustainbench_crop_yield.py | 22 +--- torchgeo/datasets/ucmerced.py | 45 ++----- torchgeo/datasets/usavars.py | 17 +-- torchgeo/datasets/utils.py | 69 +++++++++++ torchgeo/datasets/vaihingen.py | 17 ++- torchgeo/datasets/vhr10.py | 15 +-- .../western_usa_live_fuel_moisture.py | 23 ++-- torchgeo/datasets/xview.py | 23 ++-- torchgeo/datasets/zuericrop.py | 18 +-- torchgeo/trainers/base.py | 12 -- torchgeo/trainers/classification.py | 2 + torchgeo/trainers/detection.py | 1 + torchgeo/trainers/regression.py | 1 + torchgeo/trainers/segmentation.py | 1 + torchgeo/transforms/transforms.py | 2 +- 177 files changed, 1371 insertions(+), 1244 deletions(-) create mode 100644 tests/.DS_Store create mode 100644 tests/data/south_america_soybean/.DS_Store create mode 100644 tests/data/south_america_soybean/data.py create mode 100644 tests/datamodules/test_levircd.py create mode 100644 tests/datasets/test_south_america_soybean.py create mode 100644 torchgeo/datamodules/levircd.py diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 4333e29cfb9..64fee06d844 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Clone repo - uses: actions/checkout@v4.1.0 + uses: actions/checkout@v4.1.1 - name: 
Set up python uses: actions/setup-python@v4.7.1 with: @@ -40,7 +40,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Clone repo - uses: actions/checkout@v4.1.0 + uses: actions/checkout@v4.1.1 - name: Set up python uses: actions/setup-python@v4.7.1 with: @@ -68,7 +68,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Clone repo - uses: actions/checkout@v4.1.0 + uses: actions/checkout@v4.1.1 - name: Set up python uses: actions/setup-python@v4.7.1 with: diff --git a/.github/workflows/style.yaml b/.github/workflows/style.yaml index 5359e2fd18b..5cccfd96032 100644 --- a/.github/workflows/style.yaml +++ b/.github/workflows/style.yaml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Clone repo - uses: actions/checkout@v4.1.0 + uses: actions/checkout@v4.1.1 - name: Set up python uses: actions/setup-python@v4.7.1 with: @@ -39,7 +39,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Clone repo - uses: actions/checkout@v4.1.0 + uses: actions/checkout@v4.1.1 - name: Set up python uses: actions/setup-python@v4.7.1 with: @@ -64,7 +64,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Clone repo - uses: actions/checkout@v4.1.0 + uses: actions/checkout@v4.1.1 - name: Set up python uses: actions/setup-python@v4.7.1 with: @@ -89,7 +89,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Clone repo - uses: actions/checkout@v4.1.0 + uses: actions/checkout@v4.1.1 - name: Set up python uses: actions/setup-python@v4.7.1 with: @@ -114,7 +114,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Clone repo - uses: actions/checkout@v4.1.0 + uses: actions/checkout@v4.1.1 - name: Set up python uses: actions/setup-python@v4.7.1 with: diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index f34d52a2f6e..a078180b0c1 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Clone repo - uses: actions/checkout@v4.1.0 + uses: actions/checkout@v4.1.1 - name: Set up python uses: actions/setup-python@v4.7.1 with: @@ -45,7 +45,7 @@ jobs: python-version: ['3.9', '3.10', '3.11'] steps: - name: Clone repo - uses: actions/checkout@v4.1.0 + uses: actions/checkout@v4.1.1 - name: Set up python uses: actions/setup-python@v4.7.1 with: @@ -68,7 +68,7 @@ jobs: run: brew install rar if: ${{ runner.os == 'macOS' }} - name: Install choco dependencies (Windows) - run: choco install unrar + run: choco install 7zip if: ${{ runner.os == 'Windows' }} - name: Install pip dependencies if: steps.cache.outputs.cache-hit != 'true' @@ -92,7 +92,7 @@ jobs: MPLBACKEND: Agg steps: - name: Clone repo - uses: actions/checkout@v4.1.0 + uses: actions/checkout@v4.1.1 - name: Set up python uses: actions/setup-python@v4.7.1 with: diff --git a/.github/workflows/tutorials.yaml b/.github/workflows/tutorials.yaml index 621ce1064fe..8cb3c344f04 100644 --- a/.github/workflows/tutorials.yaml +++ b/.github/workflows/tutorials.yaml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Clone repo - uses: actions/checkout@v4.1.0 + uses: actions/checkout@v4.1.1 - name: Set up python uses: actions/setup-python@v4.7.1 with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index da8e338b2c2..044afa0738c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/asottile/pyupgrade - rev: v3.3.1 + rev: v3.15.0 hooks: - id: pyupgrade args: [--py39-plus] @@ -12,13 +12,13 @@ repos: additional_dependencies: ['.[colors]'] - repo: https://github.com/psf/black - rev: 23.1.0 + rev: 23.10.1 hooks: - id: black 
      args: [--skip-magic-trailing-comma]

  - repo: https://github.com/pycqa/flake8.git
-   rev: 6.0.0
+   rev: 6.1.0
    hooks:
    - id: flake8
@@ -30,9 +30,9 @@ repos:
      additional_dependencies: ['.[toml]']

  - repo: https://github.com/pre-commit/mirrors-mypy
-   rev: v1.0.1
+   rev: v1.6.1
    hooks:
    - id: mypy
      args: [--strict, --ignore-missing-imports, --show-error-codes]
-     additional_dependencies: [torch>=2, torchmetrics>=0.10, lightning>=2.0.9, pytest>=6.1.2, pyvista>=0.34.2, kornia>=0.6.5, numpy>=1.22]
+     additional_dependencies: [kornia>=0.6.5, lightning>=2.0.9, matplotlib>=3.8.1, numpy>=1.22, pytest>=6.1.2, pyvista>=0.34.2, torch>=2, torchmetrics>=0.10]
      exclude: (build|data|dist|logo|logs|output)/
diff --git a/README.md b/README.md
index eda6fc956c3..279a72fbf5f 100644
--- a/README.md
+++ b/README.md
@@ -122,6 +122,21 @@ for batch in dataloader:

All TorchGeo datasets are compatible with PyTorch data loaders, making them easy to integrate into existing training workflows. The only difference between a benchmark dataset in TorchGeo and a similar dataset in torchvision is that each dataset returns a dictionary with keys for each PyTorch `Tensor`.

### Pre-trained Weights

Pre-trained weights have proven to be tremendously beneficial for transfer learning tasks in computer vision. Practitioners usually use models pre-trained on the ImageNet dataset, which contains RGB images. However, remote sensing data often goes beyond RGB, with additional multispectral channels that can vary across sensors. TorchGeo is the first library to support models pre-trained on different multispectral sensors, and adopts torchvision's [multi-weight API](https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/). A summary of currently available weights can be seen in the [docs](https://torchgeo.readthedocs.io/en/stable/api/models.html#pretrained-weights). To create a [timm](https://github.com/huggingface/pytorch-image-models) ResNet-18 model with weights that have been pre-trained on Sentinel-2 imagery, you can do the following:

```python
import timm
from torchgeo.models import ResNet18_Weights

weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"], num_classes=10)
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
```

These weights can also be used directly in the TorchGeo Lightning modules shown in the following section via the `weights` argument. For a notebook example, see this [tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/pretrained_weights.html).

### Reproducibility with Lightning

In order to facilitate direct comparisons between results published in the literature and further reduce the boilerplate code needed to run experiments with datasets in TorchGeo, we have created Lightning [*datamodules*](https://torchgeo.readthedocs.io/en/stable/api/datamodules.html) with well-defined train-val-test splits and [*trainers*](https://torchgeo.readthedocs.io/en/stable/api/trainers.html) for various tasks like classification, regression, and semantic segmentation. These datamodules show how to incorporate augmentations from the kornia library, include preprocessing transforms (with pre-calculated channel statistics), and let users easily experiment with hyperparameters related to the data itself (as opposed to the modeling process).
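As a quick illustration of the datamodule-plus-trainer pattern described above, here is a minimal sketch that trains a classifier on EuroSAT; the datamodule, task, and argument names follow the public TorchGeo API, but the hyperparameter values are illustrative rather than tuned recommendations:

```python
from lightning.pytorch import Trainer

from torchgeo.datamodules import EuroSATDataModule
from torchgeo.trainers import ClassificationTask

# Datamodule with pre-defined splits and pre-calculated channel statistics;
# extra keyword arguments are forwarded to the underlying EuroSAT dataset
datamodule = EuroSATDataModule(batch_size=64, num_workers=4, root="data", download=True)

# EuroSAT imagery has 13 spectral bands and 10 land-use classes
task = ClassificationTask(model="resnet18", in_channels=13, num_classes=10, loss="ce", lr=1e-3)

trainer = Trainer(max_epochs=10)
trainer.fit(model=task, datamodule=datamodule)
```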
Training a semantic segmentation model on the [Inria Aerial Image Labeling](https://project.inria.fr/aerialimagelabeling/) dataset is as easy as a few imports and four lines of code.

diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 5753d529a84..5d81d11f1c6 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -178,7 +178,7 @@ BioMassters
 ^^^^^^^^^^^

 .. autoclass:: BioMassters
- 
+

 Cloud Cover Detection
 ^^^^^^^^^^^^^^^^^^^^^
@@ -464,3 +464,8 @@ Splitting Functions
 .. autofunction:: random_grid_cell_assignment
 .. autofunction:: roi_split
 .. autofunction:: time_series_split
+
+Errors
+------
+
+.. autoclass:: DatasetNotFoundError
diff --git a/docs/api/geo_datasets.csv b/docs/api/geo_datasets.csv
index 54cf53b9f27..a8fd3e26742 100644
--- a/docs/api/geo_datasets.csv
+++ b/docs/api/geo_datasets.csv
@@ -20,3 +20,4 @@ Dataset,Type,Source,Size (px),Resolution (m)
 `NLCD`_,Masks,Landsat,-,30
 `Open Buildings`_,Geometries,"Maxar, CNES/Airbus",-,-
 `Sentinel`_,Imagery,Sentinel,"10,000x10,000",10
+`South America Soybean`_,Masks,Landsat,-,30
diff --git a/docs/api/non_geo_datasets.csv b/docs/api/non_geo_datasets.csv
index db673e0c731..8e132d08e0d 100644
--- a/docs/api/non_geo_datasets.csv
+++ b/docs/api/non_geo_datasets.csv
@@ -2,7 +2,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
 `ADVANCE`_,C,"Google Earth, Freesound","5,075",13,512x512,0.5,RGB
 `Benin Cashew Plantations`_,S,Airbus Pléiades,70,6,"1,122x1,186",10,MSI
 `BigEarthNet`_,C,Sentinel-1/2,"590,326",19--43,120x120,10,"SAR, MSI"
-`BioMassters`_,R,Sentinel-1/2 and Lidar,,256, 10, "SAR, MSI"
+`BioMassters`_,R,Sentinel-1/2 and Lidar,,,256x256,10,"SAR, MSI"
 `Cloud Cover Detection`_,S,Sentinel-2,"22,728",2,512x512,10,MSI
 `COWC`_,"C, R","CSUAV AFRL, ISPRS, LINZ, AGRC","388,435",2,256x256,0.15,RGB
 `Kenya Crop Type`_,S,Sentinel-2,"4,688",7,"3,035x2,016",10,MSI
diff --git a/experiments/ssl4eo/flops.py b/experiments/ssl4eo/flops.py
index 985a1e72fc1..6bffb1835f9 100755
--- a/experiments/ssl4eo/flops.py
+++ b/experiments/ssl4eo/flops.py
@@ -22,7 +22,7 @@
     # Calculate memory requirements of model
     mem_params = sum([p.nelement() * p.element_size() for p in m.parameters()])
     mem_bufs = sum([b.nelement() * b.element_size() for b in m.buffers()])
-    mem = (mem_params + mem_bufs) / 2**20
+    mem = (mem_params + mem_bufs) / 1000000
     print(f"Memory: {mem:.2f} MB")

     with get_accelerator().device(0):
diff --git a/experiments/ssl4eo/landsat/README.md b/experiments/ssl4eo/landsat/README.md
index 9ceeab439e3..d681ea28986 100644
--- a/experiments/ssl4eo/landsat/README.md
+++ b/experiments/ssl4eo/landsat/README.md
@@ -89,7 +89,7 @@ This will create patches of NLCD and CDL data with the same locations and dimens

Using either the newly created datasets or after downloading the datasets from Hugging Face, you can run each experiment using:

```console
-$ torchgeo --config *.yaml
+$ torchgeo fit --config *.yaml
```

The config files to be passed can be found in the `conf/` directory. Feel free to tweak any hyperparameters you see in these files. The default values are the optimal hyperparameters we found.
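The `DatasetNotFoundError` registered in `docs/api/datasets.rst` above is the exception that the test updates below expect in place of the previous mix of `RuntimeError` and `FileNotFoundError`. A minimal sketch of the standardized behavior, using EuroSAT and a hypothetical empty `root` directory:

```python
from torchgeo.datasets import DatasetNotFoundError, EuroSAT

try:
    # root points to a hypothetical empty directory, and automatic
    # downloading is disabled, so the dataset cannot be found
    EuroSAT(root="/tmp/empty", download=False)
except DatasetNotFoundError as err:
    print(err)  # message suggests a different root or download=True
```

Because every dataset now raises the same exception with a consistent message, the tests below can uniformly assert `pytest.raises(DatasetNotFoundError, match="Dataset not found")`.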
diff --git a/experiments/ssl4eo/landsat/plot_landsat_bands.py b/experiments/ssl4eo/landsat/plot_landsat_bands.py
index 591bfeba49a..edfd6c86ec7 100755
--- a/experiments/ssl4eo/landsat/plot_landsat_bands.py
+++ b/experiments/ssl4eo/landsat/plot_landsat_bands.py
@@ -163,4 +163,4 @@

 plt.tight_layout()
 plt.subplots_adjust(wspace=0.05)
-plt.show()  # type: ignore[no-untyped-call]
+plt.show()
diff --git a/experiments/ssl4eo/landsat/plot_landsat_timeline.py b/experiments/ssl4eo/landsat/plot_landsat_timeline.py
index ca043500755..94eb40dbe0e 100755
--- a/experiments/ssl4eo/landsat/plot_landsat_timeline.py
+++ b/experiments/ssl4eo/landsat/plot_landsat_timeline.py
@@ -141,4 +141,4 @@
     ax.spines[["top", "right"]].set_visible(False)

 plt.tight_layout()
-plt.show()  # type: ignore[no-untyped-call]
+plt.show()
diff --git a/experiments/torchgeo/plot_bar_chart.py b/experiments/torchgeo/plot_bar_chart.py
index 6eac7b8bb98..9f421183789 100755
--- a/experiments/torchgeo/plot_bar_chart.py
+++ b/experiments/torchgeo/plot_bar_chart.py
@@ -74,4 +74,4 @@
 plt.gca().spines.right.set_visible(False)
 plt.gca().spines.top.set_visible(False)
 plt.tight_layout()
-plt.show()  # type: ignore[no-untyped-call]
+plt.show()
diff --git a/experiments/torchgeo/plot_dataloader_benchmark.py b/experiments/torchgeo/plot_dataloader_benchmark.py
index cc77ce4cf86..4c313ca174f 100755
--- a/experiments/torchgeo/plot_dataloader_benchmark.py
+++ b/experiments/torchgeo/plot_dataloader_benchmark.py
@@ -40,4 +40,4 @@
 plt.gca().spines.right.set_visible(False)
 plt.gca().spines.top.set_visible(False)
 plt.tight_layout()
-plt.show()  # type: ignore[no-untyped-call]
+plt.show()
diff --git a/experiments/torchgeo/plot_percentage_benchmark.py b/experiments/torchgeo/plot_percentage_benchmark.py
index a1b69160e6f..e0f2aa3b0e8 100755
--- a/experiments/torchgeo/plot_percentage_benchmark.py
+++ b/experiments/torchgeo/plot_percentage_benchmark.py
@@ -53,4 +53,4 @@
 ax.set_xlabel("batch size")
 ax.set_ylabel("% sampling rate (patches/sec)")
 ax.legend()
-plt.show()  # type: ignore[no-untyped-call]
+plt.show()
diff --git a/pyproject.toml b/pyproject.toml
index 0f3dc47652a..59982abe5ea 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -212,6 +212,9 @@ filterwarnings = [
     # https://github.com/pytorch/pytorch/pull/69823
     "ignore:distutils Version classes are deprecated. Use packaging.version instead:DeprecationWarning",
     "ignore:The distutils package is deprecated and slated for removal in Python 3.12:DeprecationWarning:torch.utils.tensorboard",
+    # https://github.com/Lightning-AI/torchmetrics/issues/2121
+    # https://github.com/Lightning-AI/torchmetrics/pull/2137
+    "ignore:The distutils package is deprecated and slated for removal in Python 3.12:DeprecationWarning:torchmetrics.utilities.imports",
     # https://github.com/Lightning-AI/lightning/issues/13256
     # https://github.com/Lightning-AI/lightning/pull/13261
     "ignore:torch.distributed._sharded_tensor will be deprecated:DeprecationWarning:torch.distributed._sharded_tensor",
diff --git a/requirements/required.txt b/requirements/required.txt
index 3b3b3c33132..d0cbe7e97f5 100644
--- a/requirements/required.txt
+++ b/requirements/required.txt
@@ -7,13 +7,13 @@ fiona==1.9.5
 kornia==0.7.0
 lightly==1.4.21
 lightning[pytorch-extra]==2.1.0
-matplotlib==3.8.0
-numpy==1.26.0
-pandas==2.1.1
-pillow==10.0.1
+matplotlib==3.8.1
+numpy==1.26.1
+pandas==2.1.2
+pillow==10.1.0
 pyproj==3.6.1
-rasterio==1.3.8.post2
-rtree==1.0.1
+rasterio==1.3.9
+rtree==1.1.0
 segmentation-models-pytorch==0.3.3
 shapely==2.0.2
 timm==0.9.2
diff --git a/requirements/style.txt b/requirements/style.txt
index 09c97aadf96..3412e7e0f85 100644
--- a/requirements/style.txt
+++ b/requirements/style.txt
@@ -1,5 +1,5 @@
 # style
-black[jupyter]==23.9.1
+black[jupyter]==23.10.1
 flake8==6.1.0
 isort[colors]==5.12.0
 pydocstyle[toml]==6.3.0
diff --git a/requirements/tests.txt b/requirements/tests.txt
index f21274e380c..1cf4888beec 100644
--- a/requirements/tests.txt
+++ b/requirements/tests.txt
@@ -1,5 +1,5 @@
 # tests
-mypy==1.6.0
-nbmake==1.4.5
-pytest==7.4.2
+mypy==1.6.1
+nbmake==1.4.6
+pytest==7.4.3
 pytest-cov==4.1.0
diff --git a/tests/.DS_Store b/tests/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6
GIT binary patch
literal 6148
[binary patch data omitted]

diff --git a/tests/data/agb_live_woody_density/data.py b/tests/data/agb_live_woody_density/data.py
         json.dump(base_file, f)

     for i in base_file["features"]:
-        filepath = os.path.basename(i["properties"]["download"])
+        filepath = os.path.basename(i["properties"]["Mg_px_1_download"])
         create_file(path=filepath, dtype="int32", num_channels=1)
diff --git a/tests/data/south_america_soybean/.DS_Store b/tests/data/south_america_soybean/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6
GIT binary patch
literal 6148
[binary patch data omitted]

diff --git a/tests/datamodules/test_levircd.py b/tests/datamodules/test_levircd.py
new file mode 100644
--- /dev/null
+++ b/tests/datamodules/test_levircd.py
@@ -0,0 +1,78 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.

import os
import shutil
from pathlib import Path

import pytest
from lightning.pytorch import Trainer
from pytest import MonkeyPatch

import torchgeo.datasets.utils
from torchgeo.datamodules import LEVIRCDPlusDataModule
from torchgeo.datasets import LEVIRCDPlus


def download_url(url: str, root: str, *args: str) -> None:
    shutil.copy(url, root)


class TestLEVIRCDPlusDataModule:
    @pytest.fixture
    def datamodule(
        self, monkeypatch: MonkeyPatch, tmp_path: Path
    ) -> LEVIRCDPlusDataModule:
        monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url)
        md5 = "1adf156f628aa32fb2e8fe6cada16c04"
        monkeypatch.setattr(LEVIRCDPlus, "md5", md5)
        url = os.path.join("tests", "data", "levircd", "LEVIR-CD+.zip")
        monkeypatch.setattr(LEVIRCDPlus, "url", url)

        root = str(tmp_path)
        dm = LEVIRCDPlusDataModule(
            root=root, download=True, num_workers=0, checksum=True, val_split_pct=0.5
        )
        dm.prepare_data()
        dm.trainer = Trainer(accelerator="cpu", max_epochs=1)
        return dm

    def test_train_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None:
        datamodule.setup("fit")
+ if datamodule.trainer: + datamodule.trainer.training = True + batch = next(iter(datamodule.train_dataloader())) + batch = datamodule.on_after_batch_transfer(batch, 0) + assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) + assert batch["image1"].shape[0] == batch["mask"].shape[0] == 8 + assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) + assert batch["image2"].shape[0] == batch["mask"].shape[0] == 8 + assert batch["image1"].shape[1] == 3 + assert batch["image2"].shape[1] == 3 + + def test_val_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None: + datamodule.setup("validate") + if datamodule.trainer: + datamodule.trainer.validating = True + batch = next(iter(datamodule.val_dataloader())) + batch = datamodule.on_after_batch_transfer(batch, 0) + if datamodule.val_split_pct > 0.0: + assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) + assert batch["image1"].shape[0] == batch["mask"].shape[0] == 8 + assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) + assert batch["image2"].shape[0] == batch["mask"].shape[0] == 8 + assert batch["image1"].shape[1] == 3 + assert batch["image2"].shape[1] == 3 + + def test_test_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None: + datamodule.setup("test") + if datamodule.trainer: + datamodule.trainer.testing = True + batch = next(iter(datamodule.test_dataloader())) + batch = datamodule.on_after_batch_transfer(batch, 0) + assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) + assert batch["image1"].shape[0] == batch["mask"].shape[0] == 8 + assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) + assert batch["image2"].shape[0] == batch["mask"].shape[0] == 8 + assert batch["image1"].shape[1] == 3 + assert batch["image2"].shape[1] == 3 diff --git a/tests/datamodules/test_oscd.py b/tests/datamodules/test_oscd.py index 10e890d044b..0c009a8ec2f 100644 --- a/tests/datamodules/test_oscd.py +++ b/tests/datamodules/test_oscd.py @@ -8,10 +8,11 @@ from lightning.pytorch import Trainer from torchgeo.datamodules import OSCDDataModule +from torchgeo.datasets import OSCD class TestOSCDDataModule: - @pytest.fixture(params=["all", "rgb"]) + @pytest.fixture(params=[OSCD.all_bands, OSCD.rgb_bands]) def datamodule(self, request: SubRequest) -> OSCDDataModule: bands = request.param root = os.path.join("tests", "data", "oscd") @@ -34,12 +35,16 @@ def test_train_dataloader(self, datamodule: OSCDDataModule) -> None: datamodule.trainer.training = True batch = next(iter(datamodule.train_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) - assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) - assert batch["image"].shape[0] == batch["mask"].shape[0] == 1 - if datamodule.bands == "all": - assert batch["image"].shape[1] == 26 + assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) + assert batch["image1"].shape[0] == batch["mask"].shape[0] == 1 + assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) + assert batch["image2"].shape[0] == batch["mask"].shape[0] == 1 + if datamodule.bands == OSCD.all_bands: + assert batch["image1"].shape[1] == 13 + assert batch["image2"].shape[1] == 13 else: - assert batch["image"].shape[1] == 6 + assert batch["image1"].shape[1] == 3 + assert batch["image2"].shape[1] == 3 def test_val_dataloader(self, datamodule: OSCDDataModule) -> None: datamodule.setup("validate") @@ -48,12 +53,16 @@ def test_val_dataloader(self, datamodule: 
OSCDDataModule) -> None: batch = next(iter(datamodule.val_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) if datamodule.val_split_pct > 0.0: - assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) - assert batch["image"].shape[0] == batch["mask"].shape[0] == 1 - if datamodule.bands == "all": - assert batch["image"].shape[1] == 26 + assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) + assert batch["image1"].shape[0] == batch["mask"].shape[0] == 1 + assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) + assert batch["image2"].shape[0] == batch["mask"].shape[0] == 1 + if datamodule.bands == OSCD.all_bands: + assert batch["image1"].shape[1] == 13 + assert batch["image2"].shape[1] == 13 else: - assert batch["image"].shape[1] == 6 + assert batch["image1"].shape[1] == 3 + assert batch["image2"].shape[1] == 3 def test_test_dataloader(self, datamodule: OSCDDataModule) -> None: datamodule.setup("test") @@ -61,9 +70,13 @@ def test_test_dataloader(self, datamodule: OSCDDataModule) -> None: datamodule.trainer.testing = True batch = next(iter(datamodule.test_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) - assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) - assert batch["image"].shape[0] == batch["mask"].shape[0] == 1 - if datamodule.bands == "all": - assert batch["image"].shape[1] == 26 + assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) + assert batch["image1"].shape[0] == batch["mask"].shape[0] == 1 + assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) + assert batch["image2"].shape[0] == batch["mask"].shape[0] == 1 + if datamodule.bands == OSCD.all_bands: + assert batch["image1"].shape[1] == 13 + assert batch["image2"].shape[1] == 13 else: - assert batch["image"].shape[1] == 6 + assert batch["image1"].shape[1] == 3 + assert batch["image2"].shape[1] == 3 diff --git a/tests/datasets/test_advance.py b/tests/datasets/test_advance.py index 9da0976e3ab..bcbaba2500f 100644 --- a/tests/datasets/test_advance.py +++ b/tests/datasets/test_advance.py @@ -14,7 +14,7 @@ from pytest import MonkeyPatch import torchgeo.datasets.utils -from torchgeo.datasets import ADVANCE +from torchgeo.datasets import ADVANCE, DatasetNotFoundError def download_url(url: str, root: str, *args: str) -> None: @@ -68,7 +68,7 @@ def test_already_downloaded(self, dataset: ADVANCE) -> None: ADVANCE(root=dataset.root, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): ADVANCE(str(tmp_path)) def test_mock_missing_module( diff --git a/tests/datasets/test_agb_live_woody_density.py b/tests/datasets/test_agb_live_woody_density.py index ae775a1cabf..3e0bbbc2dc7 100644 --- a/tests/datasets/test_agb_live_woody_density.py +++ b/tests/datasets/test_agb_live_woody_density.py @@ -15,6 +15,7 @@ import torchgeo from torchgeo.datasets import ( AbovegroundLiveWoodyBiomassDensity, + DatasetNotFoundError, IntersectionDataset, UnionDataset, ) @@ -53,7 +54,7 @@ def test_getitem(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None: assert isinstance(x["mask"], torch.Tensor) def test_no_dataset(self, tmp_path: Path) -> None: - with pytest.raises(RuntimeError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): AbovegroundLiveWoodyBiomassDensity(str(tmp_path)) def test_already_downloaded( 
diff --git a/tests/datasets/test_astergdem.py b/tests/datasets/test_astergdem.py index 0a1d8fc263a..dfd41e40409 100644 --- a/tests/datasets/test_astergdem.py +++ b/tests/datasets/test_astergdem.py @@ -11,7 +11,13 @@ import torch.nn as nn from rasterio.crs import CRS -from torchgeo.datasets import AsterGDEM, BoundingBox, IntersectionDataset, UnionDataset +from torchgeo.datasets import ( + AsterGDEM, + BoundingBox, + DatasetNotFoundError, + IntersectionDataset, + UnionDataset, +) class TestAsterGDEM: @@ -26,7 +32,7 @@ def dataset(self, tmp_path: Path) -> AsterGDEM: def test_datasetmissing(self, tmp_path: Path) -> None: shutil.rmtree(tmp_path) os.makedirs(tmp_path) - with pytest.raises(RuntimeError, match="Dataset not found in"): + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): AsterGDEM(str(tmp_path)) def test_getitem(self, dataset: AsterGDEM) -> None: diff --git a/tests/datasets/test_benin_cashews.py b/tests/datasets/test_benin_cashews.py index 7255e491b8a..6d81d6876be 100644 --- a/tests/datasets/test_benin_cashews.py +++ b/tests/datasets/test_benin_cashews.py @@ -13,7 +13,7 @@ from pytest import MonkeyPatch from torch.utils.data import ConcatDataset -from torchgeo.datasets import BeninSmallHolderCashews +from torchgeo.datasets import BeninSmallHolderCashews, DatasetNotFoundError class Collection: @@ -73,7 +73,7 @@ def test_already_downloaded(self, dataset: BeninSmallHolderCashews) -> None: BeninSmallHolderCashews(root=dataset.root, download=True, api_key="") def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): BeninSmallHolderCashews(str(tmp_path)) def test_invalid_bands(self) -> None: diff --git a/tests/datasets/test_bigearthnet.py b/tests/datasets/test_bigearthnet.py index 7dfea2548cb..a0e93952244 100644 --- a/tests/datasets/test_bigearthnet.py +++ b/tests/datasets/test_bigearthnet.py @@ -13,7 +13,7 @@ from pytest import MonkeyPatch import torchgeo.datasets.utils -from torchgeo.datasets import BigEarthNet +from torchgeo.datasets import BigEarthNet, DatasetNotFoundError def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -134,10 +134,7 @@ def test_already_downloaded_not_extracted( ) def test_not_downloaded(self, tmp_path: Path) -> None: - err = "Dataset not found in `root` directory and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." 
- with pytest.raises(RuntimeError, match=err): + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): BigEarthNet(str(tmp_path)) def test_plot(self, dataset: BigEarthNet) -> None: diff --git a/tests/datasets/test_biomassters.py b/tests/datasets/test_biomassters.py index 51225f19c47..17dab2df03c 100644 --- a/tests/datasets/test_biomassters.py +++ b/tests/datasets/test_biomassters.py @@ -10,7 +10,7 @@ import pytest from _pytest.fixtures import SubRequest -from torchgeo.datasets import BioMassters +from torchgeo.datasets import BioMassters, DatasetNotFoundError class TestBioMassters: @@ -36,8 +36,7 @@ def test_invalid_bands(self, dataset: BioMassters) -> None: BioMassters(dataset.root, sensors=["S3"]) def test_not_downloaded(self, tmp_path: Path) -> None: - match = "Dataset not found" - with pytest.raises(RuntimeError, match=match): + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): BioMassters(str(tmp_path)) def test_plot(self, dataset: BioMassters) -> None: diff --git a/tests/datasets/test_cbf.py b/tests/datasets/test_cbf.py index e90a309aeaf..f53023925b5 100644 --- a/tests/datasets/test_cbf.py +++ b/tests/datasets/test_cbf.py @@ -16,6 +16,7 @@ from torchgeo.datasets import ( BoundingBox, CanadianBuildingFootprints, + DatasetNotFoundError, IntersectionDataset, UnionDataset, ) @@ -75,7 +76,7 @@ def test_plot_prediction(self, dataset: CanadianBuildingFootprints) -> None: dataset.plot(x, suptitle="Prediction") def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): CanadianBuildingFootprints(str(tmp_path)) def test_invalid_query(self, dataset: CanadianBuildingFootprints) -> None: diff --git a/tests/datasets/test_cdl.py b/tests/datasets/test_cdl.py index 50babc4c175..47d0beb8d6a 100644 --- a/tests/datasets/test_cdl.py +++ b/tests/datasets/test_cdl.py @@ -15,7 +15,13 @@ from rasterio.crs import CRS import torchgeo.datasets.utils -from torchgeo.datasets import CDL, BoundingBox, IntersectionDataset, UnionDataset +from torchgeo.datasets import ( + CDL, + BoundingBox, + DatasetNotFoundError, + IntersectionDataset, + UnionDataset, +) def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -111,7 +117,7 @@ def test_plot_prediction(self, dataset: CDL) -> None: plt.close() def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(RuntimeError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): CDL(str(tmp_path)) def test_invalid_query(self, dataset: CDL) -> None: diff --git a/tests/datasets/test_chesapeake.py b/tests/datasets/test_chesapeake.py index a659e039ca0..0692e0be00f 100644 --- a/tests/datasets/test_chesapeake.py +++ b/tests/datasets/test_chesapeake.py @@ -18,6 +18,7 @@ BoundingBox, Chesapeake13, ChesapeakeCVPR, + DatasetNotFoundError, IntersectionDataset, UnionDataset, ) @@ -70,7 +71,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None: Chesapeake13(root) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(RuntimeError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): Chesapeake13(str(tmp_path), checksum=True) def test_plot(self, dataset: Chesapeake13) -> None: @@ -193,7 +194,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None: ChesapeakeCVPR(root) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(RuntimeError, 
match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): ChesapeakeCVPR(str(tmp_path), checksum=True) def test_out_of_bounds_query(self, dataset: ChesapeakeCVPR) -> None: diff --git a/tests/datasets/test_cloud_cover.py b/tests/datasets/test_cloud_cover.py index 15fb078d49a..68e8511b3a9 100644 --- a/tests/datasets/test_cloud_cover.py +++ b/tests/datasets/test_cloud_cover.py @@ -12,7 +12,7 @@ import torch.nn as nn from pytest import MonkeyPatch -from torchgeo.datasets import CloudCoverDetection +from torchgeo.datasets import CloudCoverDetection, DatasetNotFoundError class Collection: @@ -83,7 +83,7 @@ def test_already_downloaded(self, dataset: CloudCoverDetection) -> None: CloudCoverDetection(root=dataset.root, split="test", download=True, api_key="") def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): CloudCoverDetection(str(tmp_path)) def test_plot(self, dataset: CloudCoverDetection) -> None: diff --git a/tests/datasets/test_cms_mangrove_canopy.py b/tests/datasets/test_cms_mangrove_canopy.py index 3c9ea05e65a..1ebd33a1095 100644 --- a/tests/datasets/test_cms_mangrove_canopy.py +++ b/tests/datasets/test_cms_mangrove_canopy.py @@ -12,7 +12,12 @@ from pytest import MonkeyPatch from rasterio.crs import CRS -from torchgeo.datasets import CMSGlobalMangroveCanopy, IntersectionDataset, UnionDataset +from torchgeo.datasets import ( + CMSGlobalMangroveCanopy, + DatasetNotFoundError, + IntersectionDataset, + UnionDataset, +) def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -45,7 +50,7 @@ def test_getitem(self, dataset: CMSGlobalMangroveCanopy) -> None: assert isinstance(x["mask"], torch.Tensor) def test_no_dataset(self, tmp_path: Path) -> None: - with pytest.raises(RuntimeError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): CMSGlobalMangroveCanopy(str(tmp_path)) def test_already_downloaded(self, tmp_path: Path) -> None: diff --git a/tests/datasets/test_cowc.py b/tests/datasets/test_cowc.py index 6742ecfb211..19f448f5a27 100644 --- a/tests/datasets/test_cowc.py +++ b/tests/datasets/test_cowc.py @@ -14,8 +14,7 @@ from torch.utils.data import ConcatDataset import torchgeo.datasets.utils -from torchgeo.datasets import COWCCounting, COWCDetection -from torchgeo.datasets.cowc import COWC +from torchgeo.datasets import COWC, COWCCounting, COWCDetection, DatasetNotFoundError def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -78,7 +77,7 @@ def test_invalid_split(self) -> None: COWCCounting(split="foo") def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): COWCCounting(str(tmp_path)) def test_plot(self, dataset: COWCCounting) -> None: @@ -142,7 +141,7 @@ def test_invalid_split(self) -> None: COWCDetection(split="foo") def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): COWCDetection(str(tmp_path)) def test_plot(self, dataset: COWCDetection) -> None: diff --git a/tests/datasets/test_cv4a_kenya_crop_type.py b/tests/datasets/test_cv4a_kenya_crop_type.py index 638c63128c3..22667efbfd4 100644 --- 
a/tests/datasets/test_cv4a_kenya_crop_type.py +++ b/tests/datasets/test_cv4a_kenya_crop_type.py @@ -13,7 +13,7 @@ from pytest import MonkeyPatch from torch.utils.data import ConcatDataset -from torchgeo.datasets import CV4AKenyaCropType +from torchgeo.datasets import CV4AKenyaCropType, DatasetNotFoundError class Collection: @@ -84,7 +84,7 @@ def test_already_downloaded(self, dataset: CV4AKenyaCropType) -> None: CV4AKenyaCropType(root=dataset.root, download=True, api_key="") def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): CV4AKenyaCropType(str(tmp_path)) def test_invalid_tile(self, dataset: CV4AKenyaCropType) -> None: diff --git a/tests/datasets/test_cyclone.py b/tests/datasets/test_cyclone.py index 56788bab37f..6ab894c1fb7 100644 --- a/tests/datasets/test_cyclone.py +++ b/tests/datasets/test_cyclone.py @@ -14,7 +14,7 @@ from pytest import MonkeyPatch from torch.utils.data import ConcatDataset -from torchgeo.datasets import TropicalCyclone +from torchgeo.datasets import DatasetNotFoundError, TropicalCyclone class Collection: @@ -80,7 +80,7 @@ def test_invalid_split(self) -> None: TropicalCyclone(split="foo") def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): TropicalCyclone(str(tmp_path)) def test_plot(self, dataset: TropicalCyclone) -> None: diff --git a/tests/datasets/test_deepglobelandcover.py b/tests/datasets/test_deepglobelandcover.py index da243efc944..1ab9b70b2d1 100644 --- a/tests/datasets/test_deepglobelandcover.py +++ b/tests/datasets/test_deepglobelandcover.py @@ -12,7 +12,7 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -from torchgeo.datasets import DeepGlobeLandCover +from torchgeo.datasets import DatasetNotFoundError, DeepGlobeLandCover class TestDeepGlobeLandCover: @@ -55,12 +55,7 @@ def test_invalid_split(self) -> None: DeepGlobeLandCover(split="foo") def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises( - RuntimeError, - match="Dataset not found in `root`, either" - + " specify a different `root` directory or manually download" - + " the dataset to this directory.", - ): + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): DeepGlobeLandCover(str(tmp_path)) def test_plot(self, dataset: DeepGlobeLandCover) -> None: diff --git a/tests/datasets/test_dfc2022.py b/tests/datasets/test_dfc2022.py index 4b1a5221506..22caebcfb7b 100644 --- a/tests/datasets/test_dfc2022.py +++ b/tests/datasets/test_dfc2022.py @@ -12,7 +12,7 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -from torchgeo.datasets import DFC2022 +from torchgeo.datasets import DFC2022, DatasetNotFoundError class TestDFC2022: @@ -74,7 +74,7 @@ def test_invalid_split(self) -> None: DFC2022(split="foo") def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"): + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): DFC2022(str(tmp_path)) def test_plot(self, dataset: DFC2022) -> None: diff --git a/tests/datasets/test_eddmaps.py b/tests/datasets/test_eddmaps.py index 9dcb4859d30..a15adbeecaf 100644 --- a/tests/datasets/test_eddmaps.py +++ b/tests/datasets/test_eddmaps.py @@ -6,7 +6,13 @@ import pytest -from torchgeo.datasets import 
BoundingBox, EDDMapS, IntersectionDataset, UnionDataset
+from torchgeo.datasets import (
+    BoundingBox,
+    DatasetNotFoundError,
+    EDDMapS,
+    IntersectionDataset,
+    UnionDataset,
+)
 
 
 class TestEDDMapS:
@@ -31,7 +37,7 @@ def test_or(self, dataset: EDDMapS) -> None:
         assert isinstance(ds, UnionDataset)
 
     def test_no_data(self, tmp_path: Path) -> None:
-        with pytest.raises(FileNotFoundError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             EDDMapS(str(tmp_path))
 
     def test_invalid_query(self, dataset: EDDMapS) -> None:
diff --git a/tests/datasets/test_enviroatlas.py b/tests/datasets/test_enviroatlas.py
index 8f65119fa84..da7641b47a8 100644
--- a/tests/datasets/test_enviroatlas.py
+++ b/tests/datasets/test_enviroatlas.py
@@ -16,6 +16,7 @@ import torchgeo.datasets.utils
 from torchgeo.datasets import (
     BoundingBox,
+    DatasetNotFoundError,
     EnviroAtlas,
     IntersectionDataset,
     UnionDataset,
@@ -88,7 +89,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
         EnviroAtlas(root)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             EnviroAtlas(str(tmp_path), checksum=True)
 
     def test_out_of_bounds_query(self, dataset: EnviroAtlas) -> None:
diff --git a/tests/datasets/test_esri2020.py b/tests/datasets/test_esri2020.py
index 60c963139a1..1e01e0ac11d 100644
--- a/tests/datasets/test_esri2020.py
+++ b/tests/datasets/test_esri2020.py
@@ -13,7 +13,13 @@ from rasterio.crs import CRS
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import BoundingBox, Esri2020, IntersectionDataset, UnionDataset
+from torchgeo.datasets import (
+    BoundingBox,
+    DatasetNotFoundError,
+    Esri2020,
+    IntersectionDataset,
+    UnionDataset,
+)
 
 
 def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -60,7 +66,7 @@ def test_not_extracted(self, tmp_path: Path) -> None:
         Esri2020(str(tmp_path))
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             Esri2020(str(tmp_path), checksum=True)
 
     def test_and(self, dataset: Esri2020) -> None:
diff --git a/tests/datasets/test_etci2021.py b/tests/datasets/test_etci2021.py
index c386005f182..8ee695bbcab 100644
--- a/tests/datasets/test_etci2021.py
+++ b/tests/datasets/test_etci2021.py
@@ -13,7 +13,7 @@ from pytest import MonkeyPatch
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import ETCI2021
+from torchgeo.datasets import ETCI2021, DatasetNotFoundError
 
 
 def download_url(url: str, root: str, *args: str) -> None:
@@ -77,7 +77,7 @@ def test_invalid_split(self) -> None:
         ETCI2021(split="foo")
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             ETCI2021(str(tmp_path))
 
     def test_plot(self, dataset: ETCI2021) -> None:
diff --git a/tests/datasets/test_eudem.py b/tests/datasets/test_eudem.py
index e3a5efdbe25..9816ea3c84d 100644
--- a/tests/datasets/test_eudem.py
+++ b/tests/datasets/test_eudem.py
@@ -12,7 +12,13 @@ from pytest import MonkeyPatch
 from rasterio.crs import CRS
 
-from torchgeo.datasets import EUDEM, BoundingBox, IntersectionDataset, UnionDataset
+from torchgeo.datasets import (
+    EUDEM,
+    BoundingBox,
+    DatasetNotFoundError,
+    IntersectionDataset,
+    UnionDataset,
+)
 
 
 class TestEUDEM:
@@ -41,7 +47,7 @@ def test_extracted_already(self, dataset: EUDEM) -> None:
     def test_no_dataset(self, tmp_path: Path) -> None:
         shutil.rmtree(tmp_path)
         os.makedirs(tmp_path)
-        with pytest.raises(RuntimeError, match="Dataset not found in"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             EUDEM(str(tmp_path))
 
     def test_corrupted(self, tmp_path: Path) -> None:
diff --git a/tests/datasets/test_eurosat.py b/tests/datasets/test_eurosat.py
index c79b1f80bc6..5d92498b222 100644
--- a/tests/datasets/test_eurosat.py
+++ b/tests/datasets/test_eurosat.py
@@ -15,7 +15,7 @@ from torch.utils.data import ConcatDataset
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import EuroSAT, EuroSAT100
+from torchgeo.datasets import DatasetNotFoundError, EuroSAT, EuroSAT100
 
 
 def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -92,10 +92,7 @@ def test_already_downloaded_not_extracted(
         EuroSAT(root=str(tmp_path), download=False)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        err = "Dataset not found in `root` directory and `download=False`, "
-        "either specify a different `root` directory or use `download=True` "
-        "to automatically download the dataset."
-        with pytest.raises(RuntimeError, match=err):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             EuroSAT(str(tmp_path))
 
     def test_plot(self, dataset: EuroSAT) -> None:
diff --git a/tests/datasets/test_fair1m.py b/tests/datasets/test_fair1m.py
index 0983444ad82..fcb7d4f7711 100644
--- a/tests/datasets/test_fair1m.py
+++ b/tests/datasets/test_fair1m.py
@@ -13,7 +13,7 @@ from pytest import MonkeyPatch
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import FAIR1M
+from torchgeo.datasets import FAIR1M, DatasetNotFoundError
 
 
 def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None:
@@ -120,7 +120,7 @@ def test_corrupted(self, tmp_path: Path, dataset: FAIR1M) -> None:
 
     def test_not_downloaded(self, tmp_path: Path, dataset: FAIR1M) -> None:
         shutil.rmtree(str(tmp_path))
-        with pytest.raises(RuntimeError, match="Dataset not found in"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             FAIR1M(root=str(tmp_path), split=dataset.split)
 
     def test_plot(self, dataset: FAIR1M) -> None:
diff --git a/tests/datasets/test_fire_risk.py b/tests/datasets/test_fire_risk.py
index 8e42cfe9742..76689bf9e82 100644
--- a/tests/datasets/test_fire_risk.py
+++ b/tests/datasets/test_fire_risk.py
@@ -13,7 +13,7 @@ from pytest import MonkeyPatch
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import FireRisk
+from torchgeo.datasets import DatasetNotFoundError, FireRisk
 
 
 def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -56,7 +56,7 @@ def test_already_downloaded_not_extracted(
         FireRisk(root=str(tmp_path), download=False)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found in"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             FireRisk(str(tmp_path))
 
     def test_plot(self, dataset: FireRisk) -> None:
diff --git a/tests/datasets/test_forestdamage.py b/tests/datasets/test_forestdamage.py
index 8e333879f37..47caaebe5e3 100644
--- a/tests/datasets/test_forestdamage.py
+++ b/tests/datasets/test_forestdamage.py
@@ -12,7 +12,7 @@ from pytest import MonkeyPatch
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import ForestDamage
+from torchgeo.datasets import DatasetNotFoundError, ForestDamage
 
 
 def download_url(url: str, root: str, *args: str) -> None:
@@ -66,7 +66,7 @@ def test_corrupted(self, tmp_path: Path) -> None:
         ForestDamage(root=str(tmp_path), checksum=True)
 
     def test_not_found(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found in."):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             ForestDamage(str(tmp_path))
 
     def test_plot(self, dataset: ForestDamage) -> None:
diff --git a/tests/datasets/test_gbif.py b/tests/datasets/test_gbif.py
index 379b781cc93..bf6923a6bc2 100644
--- a/tests/datasets/test_gbif.py
+++ b/tests/datasets/test_gbif.py
@@ -6,7 +6,13 @@
 
 import pytest
 
-from torchgeo.datasets import GBIF, BoundingBox, IntersectionDataset, UnionDataset
+from torchgeo.datasets import (
+    GBIF,
+    BoundingBox,
+    DatasetNotFoundError,
+    IntersectionDataset,
+    UnionDataset,
+)
 
 
 class TestGBIF:
@@ -31,7 +37,7 @@ def test_or(self, dataset: GBIF) -> None:
         assert isinstance(ds, UnionDataset)
 
     def test_no_data(self, tmp_path: Path) -> None:
-        with pytest.raises(FileNotFoundError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             GBIF(str(tmp_path))
 
     def test_invalid_query(self, dataset: GBIF) -> None:
diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py
index 6b73bb0e8a4..31b140e91f2 100644
--- a/tests/datasets/test_geo.py
+++ b/tests/datasets/test_geo.py
@@ -4,7 +4,7 @@ import pickle
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Union
+from typing import Optional, Union
 
 import pytest
 import torch
@@ -16,6 +16,7 @@ from torchgeo.datasets import (
     NAIP,
     BoundingBox,
+    DatasetNotFoundError,
     GeoDataset,
     IntersectionDataset,
     NonGeoClassificationDataset,
@@ -33,11 +34,13 @@ def __init__(
         bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
         crs: CRS = CRS.from_epsg(4087),
         res: float = 1,
+        paths: Optional[Union[str, Iterable[str]]] = None,
     ) -> None:
         super().__init__()
         self.index.insert(0, tuple(bounds))
         self._crs = crs
         self.res = res
+        self.paths = paths or []
 
     def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
         hits = self.index.intersection(tuple(query), objects=True)
@@ -152,6 +155,23 @@ def test_and_nongeo(self, dataset: GeoDataset) -> None:
         ):
             dataset & ds2  # type: ignore[operator]
 
+    def test_files_property_for_non_existing_file_or_dir(self, tmp_path: Path) -> None:
+        paths = [str(tmp_path), str(tmp_path / "non_existing_file.tif")]
+        with pytest.warns(UserWarning, match="Path was ignored."):
+            assert len(CustomGeoDataset(paths=paths).files) == 0
+
+    def test_files_property_for_virtual_files(self) -> None:
+        # Tests only a subset of schemes and combinations.
+        paths = [
+            "file://directory/file.tif",
+            "zip://archive.zip!folder/file.tif",
+            "az://azure_bucket/prefix/file.tif",
+            "/vsiaz/azure_bucket/prefix/file.tif",
+            "zip+az://azure_bucket/prefix/archive.zip!folder_in_archive/file.tif",
+            "/vsizip//vsiaz/azure_bucket/prefix/archive.zip/folder_in_archive/file.tif",
+        ]
+        assert len(CustomGeoDataset(paths=paths).files) == len(paths)
+
 
 class TestRasterDataset:
     @pytest.fixture(params=zip([["R", "G", "B"], None], [True, False]))
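The two tests above pin down the behavior these hunks assume for the `files` property: local directories are searched, missing local paths are skipped with a warning, and virtual or remote paths (fsspec schemes, GDAL `/vsi` prefixes) pass through untouched. A minimal sketch of that resolution logic, under the assumption that the real property differs in details (the helper name `resolve_files` is illustrative, not part of the patch):

    import glob
    import os
    import warnings
    from collections.abc import Iterable
    from typing import Union

    def resolve_files(
        paths: Union[str, Iterable[str]], filename_glob: str = "**/*"
    ) -> list[str]:
        if isinstance(paths, str):
            paths = [paths]
        files: list[str] = []
        for path in paths:
            if os.path.isdir(path):
                # Local directories are searched recursively for matching files.
                pathname = os.path.join(path, filename_glob)
                files.extend(glob.iglob(pathname, recursive=True))
            elif os.path.isfile(path) or "://" in path or path.startswith("/vsi"):
                # Local files and virtual paths are kept as-is; existence of
                # remote objects cannot be checked cheaply.
                files.append(path)
            else:
                warnings.warn(
                    f"Could not find any relevant files for provided path "
                    f"'{path}'. Path was ignored."
                )
        return sorted(files)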
@@ -243,7 +263,7 @@ def test_invalid_query(self, sentinel: Sentinel2) -> None:
             sentinel[query]
 
     def test_no_data(self, tmp_path: Path) -> None:
-        with pytest.raises(FileNotFoundError, match="No RasterDataset data was found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             RasterDataset(str(tmp_path))
 
     def test_no_all_bands(self) -> None:
@@ -308,7 +328,7 @@ def test_invalid_query(self, dataset: CustomVectorDataset) -> None:
             dataset[query]
 
     def test_no_data(self, tmp_path: Path) -> None:
-        with pytest.raises(FileNotFoundError, match="No VectorDataset data was found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             VectorDataset(str(tmp_path))
 
diff --git a/tests/datasets/test_gid15.py b/tests/datasets/test_gid15.py
index 26d269c1354..e39619d8313 100644
--- a/tests/datasets/test_gid15.py
+++ b/tests/datasets/test_gid15.py
@@ -13,7 +13,7 @@ from pytest import MonkeyPatch
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import GID15
+from torchgeo.datasets import GID15, DatasetNotFoundError
 
 
 def download_url(url: str, root: str, *args: str) -> None:
@@ -58,7 +58,7 @@ def test_invalid_split(self) -> None:
         GID15(split="foo")
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             GID15(str(tmp_path))
 
     def test_plot(self, dataset: GID15) -> None:
diff --git a/tests/datasets/test_globbiomass.py b/tests/datasets/test_globbiomass.py
index c73675def91..5bffc3ff0a4 100644
--- a/tests/datasets/test_globbiomass.py
+++ b/tests/datasets/test_globbiomass.py
@@ -14,6 +14,7 @@
 
 from torchgeo.datasets import (
     BoundingBox,
+    DatasetNotFoundError,
     GlobBiomass,
     IntersectionDataset,
     UnionDataset,
@@ -50,7 +51,7 @@ def test_already_extracted(self, dataset: GlobBiomass) -> None:
         GlobBiomass(dataset.paths)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             GlobBiomass(str(tmp_path), checksum=True)
 
     def test_corrupted(self, tmp_path: Path) -> None:
diff --git a/tests/datasets/test_idtrees.py b/tests/datasets/test_idtrees.py
index 3ceaa34ecfb..a6192ab012e 100644
--- a/tests/datasets/test_idtrees.py
+++ b/tests/datasets/test_idtrees.py
@@ -16,7 +16,7 @@ from pytest import MonkeyPatch
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import IDTReeS
+from torchgeo.datasets import DatasetNotFoundError, IDTReeS
 
 pytest.importorskip("laspy", minversion="2")
 
@@ -91,10 +91,7 @@ def test_already_downloaded(self, dataset: IDTReeS) -> None:
         IDTReeS(root=dataset.root, download=True)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        err = "Dataset not found in `root` directory and `download=False`, "
-        "either specify a different `root` directory or use `download=True` "
-        "to automatically download the dataset."
-        with pytest.raises(RuntimeError, match=err):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             IDTReeS(str(tmp_path))
 
     def test_not_extracted(self, tmp_path: Path) -> None:
@@ -140,6 +137,7 @@ def test_plot(self, dataset: IDTReeS) -> None:
 
     def test_plot_las(self, dataset: IDTReeS) -> None:
         pyvista = pytest.importorskip("pyvista", minversion="0.34.2")
+        pyvista.OFF_SCREEN = True
 
         # Test point cloud without colors
         point_cloud = dataset.plot_las(index=0)
diff --git a/tests/datasets/test_inaturalist.py b/tests/datasets/test_inaturalist.py
index b7b3a436465..49c87d83f77 100644
--- a/tests/datasets/test_inaturalist.py
+++ b/tests/datasets/test_inaturalist.py
@@ -8,6 +8,7 @@
 
 from torchgeo.datasets import (
     BoundingBox,
+    DatasetNotFoundError,
     INaturalist,
     IntersectionDataset,
     UnionDataset,
@@ -36,7 +37,7 @@ def test_or(self, dataset: INaturalist) -> None:
         assert isinstance(ds, UnionDataset)
 
     def test_no_data(self, tmp_path: Path) -> None:
-        with pytest.raises(FileNotFoundError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             INaturalist(str(tmp_path))
 
     def test_invalid_query(self, dataset: INaturalist) -> None:
diff --git a/tests/datasets/test_inria.py b/tests/datasets/test_inria.py
index e70cd060ee6..71739a0ec8b 100644
--- a/tests/datasets/test_inria.py
+++ b/tests/datasets/test_inria.py
@@ -11,7 +11,7 @@ from _pytest.fixtures import SubRequest
 from pytest import MonkeyPatch
 
-from torchgeo.datasets import InriaAerialImageLabeling
+from torchgeo.datasets import DatasetNotFoundError, InriaAerialImageLabeling
 
 
 class TestInriaAerialImageLabeling:
@@ -49,7 +49,7 @@ def test_already_downloaded(self, dataset: InriaAerialImageLabeling) -> None:
         InriaAerialImageLabeling(root=dataset.root)
 
     def test_not_downloaded(self, tmp_path: str) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             InriaAerialImageLabeling(str(tmp_path))
 
     def test_dataset_checksum(self, dataset: InriaAerialImageLabeling) -> None:
diff --git a/tests/datasets/test_l7irish.py b/tests/datasets/test_l7irish.py
index 43795d57a1e..8b0e9a0c64f 100644
--- a/tests/datasets/test_l7irish.py
+++ b/tests/datasets/test_l7irish.py
@@ -14,7 +14,13 @@ from rasterio.crs import CRS
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import BoundingBox, IntersectionDataset, L7Irish, UnionDataset
+from torchgeo.datasets import (
+    BoundingBox,
+    DatasetNotFoundError,
+    IntersectionDataset,
+    L7Irish,
+    UnionDataset,
+)
 
 
 def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -68,7 +74,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
         L7Irish(root)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             L7Irish(str(tmp_path))
 
     def test_plot_prediction(self, dataset: L7Irish) -> None:
diff --git a/tests/datasets/test_l8biome.py b/tests/datasets/test_l8biome.py
index 337e17c1b64..ca57b70fa4e 100644
--- a/tests/datasets/test_l8biome.py
+++ b/tests/datasets/test_l8biome.py
@@ -14,7 +14,13 @@ from rasterio.crs import CRS
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import BoundingBox, IntersectionDataset, L8Biome, UnionDataset
+from torchgeo.datasets import (
+    BoundingBox,
+    DatasetNotFoundError,
+    IntersectionDataset,
+    L8Biome,
+    UnionDataset,
+)
 
 
 def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -68,7 +74,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
         L8Biome(root)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             L8Biome(str(tmp_path))
 
     def test_plot_prediction(self, dataset: L8Biome) -> None:
diff --git a/tests/datasets/test_landcoverai.py b/tests/datasets/test_landcoverai.py
index 3e0b6d2434a..e5bb366e5a2 100644
--- a/tests/datasets/test_landcoverai.py
+++ b/tests/datasets/test_landcoverai.py
@@ -14,7 +14,12 @@ from torch.utils.data import ConcatDataset
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import BoundingBox, LandCoverAI, LandCoverAIGeo
+from torchgeo.datasets import (
+    BoundingBox,
+    DatasetNotFoundError,
+    LandCoverAI,
+    LandCoverAIGeo,
+)
 
 
 def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -49,7 +54,7 @@ def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> N
         LandCoverAIGeo(root)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             LandCoverAIGeo(str(tmp_path))
 
     def test_out_of_bounds_query(self, dataset: LandCoverAIGeo) -> None:
@@ -115,7 +120,7 @@ def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> N
         LandCoverAI(root)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             LandCoverAI(str(tmp_path))
 
     def test_invalid_split(self) -> None:
diff --git a/tests/datasets/test_landsat.py b/tests/datasets/test_landsat.py
index 950d33fcb00..f85de17e2a1 100644
--- a/tests/datasets/test_landsat.py
+++ b/tests/datasets/test_landsat.py
@@ -12,7 +12,13 @@ from pytest import MonkeyPatch
 from rasterio.crs import CRS
 
-from torchgeo.datasets import BoundingBox, IntersectionDataset, Landsat8, UnionDataset
+from torchgeo.datasets import (
+    BoundingBox,
+    DatasetNotFoundError,
+    IntersectionDataset,
+    Landsat8,
+    UnionDataset,
+)
 
 
 class TestLandsat8:
@@ -60,7 +66,7 @@ def test_plot_wrong_bands(self, dataset: Landsat8) -> None:
             ds.plot(x)
 
     def test_no_data(self, tmp_path: Path) -> None:
-        with pytest.raises(FileNotFoundError, match="No Landsat8 data was found in "):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             Landsat8(str(tmp_path))
 
     def test_invalid_query(self, dataset: Landsat8) -> None:
diff --git a/tests/datasets/test_levircd.py b/tests/datasets/test_levircd.py
index b4a46d43e62..cafe1ed8206 100644
--- a/tests/datasets/test_levircd.py
+++ b/tests/datasets/test_levircd.py
@@ -13,7 +13,7 @@ from pytest import MonkeyPatch
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import LEVIRCDPlus
+from torchgeo.datasets import DatasetNotFoundError, LEVIRCDPlus
 
 
 def download_url(url: str, root: str, *args: str) -> None:
@@ -38,9 +38,11 @@ def dataset(
     def test_getitem(self, dataset: LEVIRCDPlus) -> None:
         x = dataset[0]
         assert isinstance(x, dict)
-        assert isinstance(x["image"], torch.Tensor)
+        assert isinstance(x["image1"], torch.Tensor)
+        assert isinstance(x["image2"], torch.Tensor)
         assert isinstance(x["mask"], torch.Tensor)
-        assert x["image"].shape[0] == 2
+        assert x["image1"].shape[0] == 3
+        assert x["image2"].shape[0] == 3
 
     def test_len(self, dataset: LEVIRCDPlus) -> None:
         assert len(dataset) == 2
@@ -53,7 +55,7 @@ def test_invalid_split(self) -> None:
         LEVIRCDPlus(split="foo")
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             LEVIRCDPlus(str(tmp_path))
 
     def test_plot(self, dataset: LEVIRCDPlus) -> None:
diff --git a/tests/datasets/test_loveda.py b/tests/datasets/test_loveda.py
index 666afce52ad..a368d711034 100644
--- a/tests/datasets/test_loveda.py
+++ b/tests/datasets/test_loveda.py
@@ -13,7 +13,7 @@ from pytest import MonkeyPatch
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import LoveDA
+from torchgeo.datasets import DatasetNotFoundError, LoveDA
 
 
 def download_url(url: str, root: str, *args: str) -> None:
@@ -83,9 +83,7 @@ def test_invalid_scene(self) -> None:
         LoveDA(scene=["garden"])
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(
-            RuntimeError, match="Dataset not found at root directory or corrupted."
-        ):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             LoveDA(str(tmp_path))
 
     def test_plot(self, dataset: LoveDA) -> None:
diff --git a/tests/datasets/test_mapinwild.py b/tests/datasets/test_mapinwild.py
index 7f4b4192122..90aa35b6aa2 100644
--- a/tests/datasets/test_mapinwild.py
+++ b/tests/datasets/test_mapinwild.py
@@ -15,7 +15,7 @@ from torch.utils.data import ConcatDataset
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import MapInWild
+from torchgeo.datasets import DatasetNotFoundError, MapInWild
 
 
 def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -97,7 +97,7 @@ def test_invalid_split(self) -> None:
         MapInWild(split="foo")
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             MapInWild(root=str(tmp_path))
 
     def test_downloaded_not_extracted(self, tmp_path: Path) -> None:
diff --git a/tests/datasets/test_millionaid.py b/tests/datasets/test_millionaid.py
index 751567e28a8..1e94fd003d0 100644
--- a/tests/datasets/test_millionaid.py
+++ b/tests/datasets/test_millionaid.py
@@ -11,7 +11,7 @@ import torch.nn as nn
 from _pytest.fixtures import SubRequest
 
-from torchgeo.datasets import MillionAID
+from torchgeo.datasets import DatasetNotFoundError, MillionAID
 
 
 class TestMillionAID:
@@ -38,7 +38,7 @@ def test_len(self, dataset: MillionAID) -> None:
         assert len(dataset) == 2
 
     def test_not_found(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found in"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             MillionAID(str(tmp_path))
 
     def test_not_extracted(self, tmp_path: Path) -> None:
diff --git a/tests/datasets/test_naip.py b/tests/datasets/test_naip.py
index 11e72938883..fe257ae2b78 100644
--- a/tests/datasets/test_naip.py
+++ b/tests/datasets/test_naip.py
@@ -10,7 +10,13 @@ import torch.nn as nn
 from rasterio.crs import CRS
 
-from torchgeo.datasets import NAIP, BoundingBox, IntersectionDataset, UnionDataset
+from torchgeo.datasets import (
+    NAIP,
+    BoundingBox,
+    DatasetNotFoundError,
+    IntersectionDataset,
+    UnionDataset,
+)
 
 
 class TestNAIP:
@@ -41,7 +47,7 @@ def test_plot(self, dataset: NAIP) -> None:
         plt.close()
 
     def test_no_data(self, tmp_path: Path) -> None:
-        with pytest.raises(FileNotFoundError, match="No NAIP data was found in "):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             NAIP(str(tmp_path))
 
     def test_invalid_query(self, dataset: NAIP) -> None:
diff --git a/tests/datasets/test_nasa_marine_debris.py b/tests/datasets/test_nasa_marine_debris.py
index 05c96c8dcb1..f475234ffe6 100644
--- a/tests/datasets/test_nasa_marine_debris.py
+++ b/tests/datasets/test_nasa_marine_debris.py
@@ -12,7 +12,7 @@ import torch.nn as nn
 from pytest import MonkeyPatch
 
-from torchgeo.datasets import NASAMarineDebris
+from torchgeo.datasets import DatasetNotFoundError, NASAMarineDebris
 
 
 class Collection:
@@ -90,10 +90,7 @@ def test_corrupted_new_download(
         NASAMarineDebris(root=str(tmp_path), download=True, checksum=True)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        err = "Dataset not found in `root` directory and `download=False`, "
-        "either specify a different `root` directory or use `download=True` "
-        "to automatically download the dataset."
-        with pytest.raises(RuntimeError, match=err):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             NASAMarineDebris(str(tmp_path))
 
     def test_plot(self, dataset: NASAMarineDebris) -> None:
diff --git a/tests/datasets/test_nlcd.py b/tests/datasets/test_nlcd.py
index ceee8097634..67dde52648d 100644
--- a/tests/datasets/test_nlcd.py
+++ b/tests/datasets/test_nlcd.py
@@ -13,7 +13,13 @@ from rasterio.crs import CRS
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import NLCD, BoundingBox, IntersectionDataset, UnionDataset
+from torchgeo.datasets import (
+    NLCD,
+    BoundingBox,
+    DatasetNotFoundError,
+    IntersectionDataset,
+    UnionDataset,
+)
 
 
 def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -107,7 +113,7 @@ def test_plot_prediction(self, dataset: NLCD) -> None:
         plt.close()
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             NLCD(str(tmp_path))
 
     def test_invalid_query(self, dataset: NLCD) -> None:
diff --git a/tests/datasets/test_openbuildings.py b/tests/datasets/test_openbuildings.py
index 4611599954f..65244962553 100644
--- a/tests/datasets/test_openbuildings.py
+++ b/tests/datasets/test_openbuildings.py
@@ -16,6 +16,7 @@
 
 from torchgeo.datasets import (
     BoundingBox,
+    DatasetNotFoundError,
     IntersectionDataset,
     OpenBuildings,
     UnionDataset,
@@ -52,16 +53,9 @@ def test_no_shapes_to_rasterize(
         assert isinstance(x["crs"], CRS)
         assert isinstance(x["mask"], torch.Tensor)
 
-    def test_no_building_data_found(self, tmp_path: Path) -> None:
-        false_root = os.path.join(tmp_path, "empty")
-        os.makedirs(false_root)
-        shutil.copy(
-            os.path.join("tests", "data", "openbuildings", "tiles.geojson"), false_root
-        )
-        with pytest.raises(
-            RuntimeError, match="have manually downloaded the dataset as suggested "
-        ):
-            OpenBuildings(false_root)
+    def test_not_download(self, tmp_path: Path) -> None:
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
+            OpenBuildings(str(tmp_path))
 
     def test_corrupted(self, dataset: OpenBuildings, tmp_path: Path) -> None:
         with open(os.path.join(tmp_path, "000_buildings.csv.gz"), "w") as f:
@@ -69,12 +63,6 @@ def test_corrupted(self, dataset: OpenBuildings, tmp_path: Path) -> None:
         with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
             OpenBuildings(dataset.paths, checksum=True)
 
-    def test_no_meta_data_found(self, tmp_path: Path) -> None:
-        false_root = os.path.join(tmp_path, "empty")
-        os.makedirs(false_root)
-        with pytest.raises(FileNotFoundError, match="Meta data file"):
-            OpenBuildings(false_root)
-
     def test_nothing_in_index(self, dataset: OpenBuildings, tmp_path: Path) -> None:
         # change meta data to another 'title_url' so that there is no match found
         with open(os.path.join(tmp_path, "tiles.geojson")) as f:
@@ -84,7 +72,7 @@ def test_nothing_in_index(self, dataset: OpenBuildings, tmp_path: Path) -> None:
         with open(os.path.join(tmp_path, "tiles.geojson"), "w") as f:
             json.dump(content, f)
 
-        with pytest.raises(FileNotFoundError, match="data was found in"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             OpenBuildings(dataset.paths)
 
     def test_getitem(self, dataset: OpenBuildings) -> None:
diff --git a/tests/datasets/test_oscd.py b/tests/datasets/test_oscd.py
index 78cbd30b8f7..82501f016d3 100644
--- a/tests/datasets/test_oscd.py
+++ b/tests/datasets/test_oscd.py
@@ -15,7 +15,7 @@ from torch.utils.data import ConcatDataset
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import OSCD
+from torchgeo.datasets import OSCD, DatasetNotFoundError
 
 
 def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -23,7 +23,7 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
 
 class TestOSCD:
-    @pytest.fixture(params=zip(["all", "rgb"], ["train", "test"]))
+    @pytest.fixture(params=zip([OSCD.all_bands, OSCD.rgb_bands], ["train", "test"]))
     def dataset(
         self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
     ) -> OSCD:
@@ -72,15 +72,19 @@ def dataset(
     def test_getitem(self, dataset: OSCD) -> None:
         x = dataset[0]
         assert isinstance(x, dict)
-        assert isinstance(x["image"], torch.Tensor)
-        assert x["image"].ndim == 3
+        assert isinstance(x["image1"], torch.Tensor)
+        assert x["image1"].ndim == 3
+        assert isinstance(x["image2"], torch.Tensor)
+        assert x["image2"].ndim == 3
         assert isinstance(x["mask"], torch.Tensor)
         assert x["mask"].ndim == 2
 
-        if dataset.bands == "rgb":
-            assert x["image"].shape[0] == 6
+        if dataset.bands == OSCD.rgb_bands:
+            assert x["image1"].shape[0] == 3
+            assert x["image2"].shape[0] == 3
         else:
-            assert x["image"].shape[0] == 26
+            assert x["image1"].shape[0] == 13
+            assert x["image2"].shape[0] == 13
 
     def test_len(self, dataset: OSCD) -> None:
         if dataset.split == "train":
@@ -103,9 +107,15 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
         OSCD(root)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             OSCD(str(tmp_path))
 
     def test_plot(self, dataset: OSCD) -> None:
         dataset.plot(dataset[0], suptitle="Test")
         plt.close()
+
+    def test_failed_plot(self, dataset: OSCD) -> None:
+        single_band_dataset = OSCD(root=dataset.root, bands=("B01",))
+        with pytest.raises(ValueError, match="RGB bands must be present"):
+            x = single_band_dataset[0].copy()
+            single_band_dataset.plot(x, suptitle="Test")
diff --git a/tests/datasets/test_pastis.py b/tests/datasets/test_pastis.py
index 698d12487b5..1decc20e0c8 100644
--- a/tests/datasets/test_pastis.py
+++ b/tests/datasets/test_pastis.py
@@ -14,7 +14,7 @@ from torch.utils.data import ConcatDataset
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import PASTIS
+from torchgeo.datasets import PASTIS, DatasetNotFoundError
 
 
 def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -80,7 +80,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
         PASTIS(root)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             PASTIS(str(tmp_path))
 
     def test_corrupted(self, tmp_path: Path) -> None:
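The LEVIR-CD+ and OSCD hunks above also change the sample format for change detection: instead of one stacked tensor under "image" (2*C channels), each acquisition date gets its own key. A minimal sketch of the sample dictionary those assertions expect, with placeholder shapes, assuming 3-band RGB patches:

    import torch

    def change_detection_sample() -> dict[str, torch.Tensor]:
        # One tensor per acquisition date rather than a 2*C-channel stack, so
        # per-image transforms and normalization remain straightforward.
        return {
            "image1": torch.zeros(3, 256, 256),  # pre-change image, CxHxW
            "image2": torch.zeros(3, 256, 256),  # post-change image, CxHxW
            "mask": torch.zeros(256, 256, dtype=torch.long),  # change mask, HxW
        }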
diff --git a/tests/datasets/test_patternnet.py b/tests/datasets/test_patternnet.py
index 7e06264bf92..efab8bd7b31 100644
--- a/tests/datasets/test_patternnet.py
+++ b/tests/datasets/test_patternnet.py
@@ -12,7 +12,7 @@ from pytest import MonkeyPatch
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import PatternNet
+from torchgeo.datasets import DatasetNotFoundError, PatternNet
 
 
 def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -52,10 +52,7 @@ def test_already_downloaded_not_extracted(
         PatternNet(root=str(tmp_path), download=False)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        err = "Dataset not found in `root` directory and `download=False`, "
-        "either specify a different `root` directory or use `download=True` "
-        "to automatically download the dataset."
-        with pytest.raises(RuntimeError, match=err):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
            PatternNet(str(tmp_path))
 
     def test_plot(self, dataset: PatternNet) -> None:
diff --git a/tests/datasets/test_potsdam.py b/tests/datasets/test_potsdam.py
index 7502a3db63b..b803b15ea95 100644
--- a/tests/datasets/test_potsdam.py
+++ b/tests/datasets/test_potsdam.py
@@ -12,7 +12,7 @@ from _pytest.fixtures import SubRequest
 from pytest import MonkeyPatch
 
-from torchgeo.datasets import Potsdam2D
+from torchgeo.datasets import DatasetNotFoundError, Potsdam2D
 
 
 class TestPotsdam2D:
@@ -60,7 +60,7 @@ def test_invalid_split(self) -> None:
         Potsdam2D(split="foo")
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             Potsdam2D(str(tmp_path))
 
     def test_plot(self, dataset: Potsdam2D) -> None:
diff --git a/tests/datasets/test_reforestree.py b/tests/datasets/test_reforestree.py
index 6a5236cdd38..b2f3a16eaef 100644
--- a/tests/datasets/test_reforestree.py
+++ b/tests/datasets/test_reforestree.py
@@ -12,7 +12,7 @@ from pytest import MonkeyPatch
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import ReforesTree
+from torchgeo.datasets import DatasetNotFoundError, ReforesTree
 
 
 def download_url(url: str, root: str, *args: str) -> None:
@@ -66,7 +66,7 @@ def test_corrupted(self, tmp_path: Path) -> None:
         ReforesTree(root=str(tmp_path), checksum=True)
 
     def test_not_found(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found in"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             ReforesTree(str(tmp_path))
 
     def test_plot(self, dataset: ReforesTree) -> None:
diff --git a/tests/datasets/test_resisc45.py b/tests/datasets/test_resisc45.py
index b20a19ddbb4..099885deac8 100644
--- a/tests/datasets/test_resisc45.py
+++ b/tests/datasets/test_resisc45.py
@@ -13,7 +13,7 @@ from pytest import MonkeyPatch
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import RESISC45
+from torchgeo.datasets import RESISC45, DatasetNotFoundError
 
 
 def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -78,10 +78,7 @@ def test_already_downloaded_not_extracted(
         RESISC45(root=str(tmp_path), download=False)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        err = "Dataset not found in `root` directory and `download=False`, "
-        "either specify a different `root` directory or use `download=True` "
-        "to automatically download the dataset."
-        with pytest.raises(RuntimeError, match=err):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             RESISC45(str(tmp_path))
 
     def test_plot(self, dataset: RESISC45) -> None:
diff --git a/tests/datasets/test_rwanda_field_boundary.py b/tests/datasets/test_rwanda_field_boundary.py
index e0736b32e7c..c0bfd71e452 100644
--- a/tests/datasets/test_rwanda_field_boundary.py
+++ b/tests/datasets/test_rwanda_field_boundary.py
@@ -14,7 +14,7 @@ from pytest import MonkeyPatch
 from torch.utils.data import ConcatDataset
 
-from torchgeo.datasets import RwandaFieldBoundary
+from torchgeo.datasets import DatasetNotFoundError, RwandaFieldBoundary
 
 
 class Collection:
@@ -87,7 +87,7 @@ def test_already_downloaded(self, dataset: RwandaFieldBoundary) -> None:
         RwandaFieldBoundary(root=dataset.root)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found in"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             RwandaFieldBoundary(str(tmp_path))
 
     def test_corrupted(self, tmp_path: Path) -> None:
diff --git a/tests/datasets/test_seasonet.py b/tests/datasets/test_seasonet.py
index f227318791c..6d7280537ab 100644
--- a/tests/datasets/test_seasonet.py
+++ b/tests/datasets/test_seasonet.py
@@ -15,7 +15,7 @@ from torch.utils.data import ConcatDataset
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import SeasoNet
+from torchgeo.datasets import DatasetNotFoundError, SeasoNet
 
 
 def download_url(url: str, root: str, md5: str, *args: str, **kwargs: str) -> None:
@@ -147,7 +147,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
         SeasoNet(root)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="not found in"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             SeasoNet(str(tmp_path), download=False)
 
     def test_out_of_bounds(self, dataset: SeasoNet) -> None:
diff --git a/tests/datasets/test_seco.py b/tests/datasets/test_seco.py
index 743a90c2b7c..f89efec4b24 100644
--- a/tests/datasets/test_seco.py
+++ b/tests/datasets/test_seco.py
@@ -15,7 +15,7 @@ from torch.utils.data import ConcatDataset
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import SeasonalContrastS2
+from torchgeo.datasets import DatasetNotFoundError, SeasonalContrastS2
 
 
 def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -98,7 +98,7 @@ def test_invalid_band(self) -> None:
         SeasonalContrastS2(bands=["A1steaksauce"])
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             SeasonalContrastS2(str(tmp_path))
 
     def test_plot(self, dataset: SeasonalContrastS2) -> None:
diff --git a/tests/datasets/test_sen12ms.py b/tests/datasets/test_sen12ms.py
index 55ecb406bf2..f802105e0c6 100644
--- a/tests/datasets/test_sen12ms.py
+++ b/tests/datasets/test_sen12ms.py
@@ -12,7 +12,7 @@ from pytest import MonkeyPatch
 from torch.utils.data import ConcatDataset
 
-from torchgeo.datasets import SEN12MS
+from torchgeo.datasets import SEN12MS, DatasetNotFoundError
 
 
 class TestSEN12MS:
@@ -65,10 +65,10 @@ def test_invalid_split(self) -> None:
         SEN12MS(split="foo")
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             SEN12MS(str(tmp_path), checksum=True)
 
-        with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             SEN12MS(str(tmp_path), checksum=False)
 
     def test_check_integrity_light(self) -> None:
diff --git a/tests/datasets/test_sentinel.py b/tests/datasets/test_sentinel.py
index fccb4e32032..f22a1c5fcc3 100644
--- a/tests/datasets/test_sentinel.py
+++ b/tests/datasets/test_sentinel.py
@@ -13,6 +13,7 @@
 
 from torchgeo.datasets import (
     BoundingBox,
+    DatasetNotFoundError,
     IntersectionDataset,
     Sentinel1,
     Sentinel2,
@@ -64,7 +65,7 @@ def test_plot(self, dataset: Sentinel2) -> None:
         plt.close()
 
     def test_no_data(self, tmp_path: Path) -> None:
-        with pytest.raises(FileNotFoundError, match="No Sentinel1 data was found in "):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             Sentinel1(str(tmp_path))
 
     def test_empty_bands(self) -> None:
@@ -123,7 +124,7 @@ def test_or(self, dataset: Sentinel2) -> None:
         assert isinstance(ds, UnionDataset)
 
     def test_no_data(self, tmp_path: Path) -> None:
-        with pytest.raises(FileNotFoundError, match="No Sentinel2 data was found in "):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             Sentinel2(str(tmp_path))
 
     def test_plot(self, dataset: Sentinel2) -> None:
diff --git a/tests/datasets/test_skippd.py b/tests/datasets/test_skippd.py
index 838f4d67032..392c3255eda 100644
--- a/tests/datasets/test_skippd.py
+++ b/tests/datasets/test_skippd.py
@@ -16,7 +16,7 @@ from pytest import MonkeyPatch
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import SKIPPD
+from torchgeo.datasets import SKIPPD, DatasetNotFoundError
 
 pytest.importorskip("h5py", minversion="3")
 
@@ -105,7 +105,7 @@ def test_invalid_split(self) -> None:
         SKIPPD(split="foo")
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found in"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             SKIPPD(str(tmp_path))
 
     def test_plot(self, dataset: SKIPPD) -> None:
diff --git a/tests/datasets/test_so2sat.py b/tests/datasets/test_so2sat.py
index 5802d81b537..2e093f288c6 100644
--- a/tests/datasets/test_so2sat.py
+++ b/tests/datasets/test_so2sat.py
@@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
 from pytest import MonkeyPatch
 
-from torchgeo.datasets import So2Sat
+from torchgeo.datasets import DatasetNotFoundError, So2Sat
 
 pytest.importorskip("h5py", minversion="3")
 
@@ -70,7 +70,7 @@ def test_invalid_bands(self) -> None:
         So2Sat(bands=("OK", "BK"))
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             So2Sat(str(tmp_path))
 
     def test_plot(self, dataset: So2Sat) -> None:
diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py
new file mode 100644
index 00000000000..50be0b291f2
--- /dev/null
+++ b/tests/datasets/test_south_america_soybean.py
@@ -0,0 +1,109 @@
+import os
+import shutil
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import pytest
+import torch
+import torch.nn as nn
+from pytest import MonkeyPatch
+from rasterio.crs import CRS
+
+import torchgeo.datasets.utils
+from torchgeo.datasets import BoundingBox, DatasetNotFoundError, IntersectionDataset, SouthAmericaSoybean, UnionDataset
+
+
+def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
+    shutil.copy(url, root)
+
+
+class TestSouthAmericaSoybean:
+    @pytest.fixture
+    def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> SouthAmericaSoybean:
+        monkeypatch.setattr(torchgeo.datasets.south_america_soybean, "download_url", download_url)
+        transforms = nn.Identity()
+        md5s = {
+            2002: "8a4a9dcea54b3ec7de07657b9f2c0893",
+            2021: "edff3ada13a1a9910d1fe844d28ae4f",
+        }
+        monkeypatch.setattr(SouthAmericaSoybean, "md5s", md5s)
+
+        url = os.path.join("tests", "data", "south_america_soybean", "SouthAmerica_Soybean_{}.tif")
+        monkeypatch.setattr(SouthAmericaSoybean, "url", url)
+
+        return SouthAmericaSoybean(
+            str(tmp_path),
+            transforms=transforms,
+            download=True,
+            checksum=True,
+            years=[2002, 2021],
+        )
+
+    def test_getitem(self, dataset: SouthAmericaSoybean) -> None:
+        x = dataset[dataset.bounds]
+        assert isinstance(x, dict)
+        assert isinstance(x["crs"], CRS)
+        assert isinstance(x["mask"], torch.Tensor)
+
+    def test_classes(self) -> None:
+        root = os.path.join("tests", "data", "south_america_soybean")
+        classes = list(SouthAmericaSoybean.cmap.keys())[0:2]
+        ds = SouthAmericaSoybean(root, years=[2021], classes=classes)
+        sample = ds[ds.bounds]
+        mask = sample["mask"]
+        assert mask.max() < len(classes)
+
+    def test_and(self, dataset: SouthAmericaSoybean) -> None:
+        ds = dataset & dataset
+        assert isinstance(ds, IntersectionDataset)
+
+    def test_or(self, dataset: SouthAmericaSoybean) -> None:
+        ds = dataset | dataset
+        assert isinstance(ds, UnionDataset)
+
+    def test_already_extracted(self, dataset: SouthAmericaSoybean) -> None:
+        SouthAmericaSoybean(dataset.paths, download=True, years=[2021])
+
+    def test_already_downloaded(self, tmp_path: Path) -> None:
+        pathname = os.path.join("tests", "data", "south_america_soybean", "SouthAmerica_Soybean_2021.tif")
+        root = str(tmp_path)
+        shutil.copy(pathname, root)
+        SouthAmericaSoybean(root, years=[2021])
+
+    def test_invalid_year(self, tmp_path: Path) -> None:
+        with pytest.raises(
+            AssertionError,
+            match="South America Soybean data only exists for the following years:",
+        ):
+            SouthAmericaSoybean(str(tmp_path), years=[1996])
+
+    def test_invalid_classes(self) -> None:
+        with pytest.raises(AssertionError):
+            SouthAmericaSoybean(classes=[-1])
+
+        with pytest.raises(AssertionError):
+            SouthAmericaSoybean(classes=[11])
+
+    def test_plot(self, dataset: SouthAmericaSoybean) -> None:
+        query = dataset.bounds
+        x = dataset[query]
+        dataset.plot(x, suptitle="Test")
+        plt.close()
+
+    def test_plot_prediction(self, dataset: SouthAmericaSoybean) -> None:
+        query = dataset.bounds
+        x = dataset[query]
+        x["prediction"] = x["mask"].clone()
+        dataset.plot(x, suptitle="Prediction")
+        plt.close()
+
+    def test_not_downloaded(self, tmp_path: Path) -> None:
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
+            SouthAmericaSoybean(str(tmp_path))
+
+    def test_invalid_query(self, dataset: SouthAmericaSoybean) -> None:
+        query = BoundingBox(0, 0, 0, 0, 0, 0)
+        with pytest.raises(
+            IndexError, match="query: .* not found in index with bounds:"
+        ):
+            dataset[query]
\ No newline at end of file
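Once this dataset lands, typical use follows the standard RasterDataset pattern: index by a bounding box, or pair it with a geospatial sampler for training patches. A hypothetical usage sketch (the paths, years, and sampler settings are placeholders, not part of the patch):

    from torch.utils.data import DataLoader

    from torchgeo.datasets import SouthAmericaSoybean, stack_samples
    from torchgeo.samplers import RandomGeoSampler

    # Downloads the requested annual maps into "data" on first use.
    ds = SouthAmericaSoybean(paths="data", years=[2021], download=True)
    sampler = RandomGeoSampler(ds, size=256, length=100)
    loader = DataLoader(ds, sampler=sampler, collate_fn=stack_samples)
    for batch in loader:
        masks = batch["mask"]  # ordinal-encoded soybean masks, one per patch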
diff --git a/tests/datasets/test_spacenet.py b/tests/datasets/test_spacenet.py
index 4b79ca3fa35..046b83cfba1 100644
--- a/tests/datasets/test_spacenet.py
+++ b/tests/datasets/test_spacenet.py
@@ -14,6 +14,7 @@ from pytest import MonkeyPatch
 
 from torchgeo.datasets import (
+    DatasetNotFoundError,
     SpaceNet1,
     SpaceNet2,
     SpaceNet3,
@@ -91,7 +92,7 @@ def test_already_downloaded(self, dataset: SpaceNet1) -> None:
         SpaceNet1(root=dataset.root, download=True)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             SpaceNet1(str(tmp_path))
 
     def test_plot(self, dataset: SpaceNet1) -> None:
@@ -147,7 +148,7 @@ def test_already_downloaded(self, dataset: SpaceNet2) -> None:
         SpaceNet2(root=dataset.root, download=True)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             SpaceNet2(str(tmp_path))
 
     def test_collection_checksum(self, dataset: SpaceNet2) -> None:
@@ -207,7 +208,7 @@ def test_already_downloaded(self, dataset: SpaceNet3) -> None:
         SpaceNet3(root=dataset.root, download=True)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             SpaceNet3(str(tmp_path))
 
     def test_collection_checksum(self, dataset: SpaceNet3) -> None:
@@ -271,7 +272,7 @@ def test_already_downloaded(self, dataset: SpaceNet4) -> None:
         SpaceNet4(root=dataset.root, download=True)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             SpaceNet4(str(tmp_path))
 
     def test_collection_checksum(self, dataset: SpaceNet4) -> None:
@@ -333,7 +334,7 @@ def test_already_downloaded(self, dataset: SpaceNet5) -> None:
         SpaceNet5(root=dataset.root, download=True)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             SpaceNet5(str(tmp_path))
 
     def test_collection_checksum(self, dataset: SpaceNet5) -> None:
@@ -427,7 +428,7 @@ def test_already_downloaded(self, dataset: SpaceNet4) -> None:
         SpaceNet7(root=dataset.root, download=True)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             SpaceNet7(str(tmp_path))
 
     def test_collection_checksum(self, dataset: SpaceNet4) -> None:
diff --git a/tests/datasets/test_ssl4eo.py b/tests/datasets/test_ssl4eo.py
index e2b9b36feff..68b6df002b4 100644
--- a/tests/datasets/test_ssl4eo.py
+++ b/tests/datasets/test_ssl4eo.py
@@ -15,7 +15,7 @@ from torch.utils.data import ConcatDataset
 
 import torchgeo
-from torchgeo.datasets import SSL4EOL, SSL4EOS12
+from torchgeo.datasets import SSL4EOL, SSL4EOS12, DatasetNotFoundError
 
 
 def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -94,7 +94,7 @@ def test_already_downloaded(self, dataset: SSL4EOL, tmp_path: Path) -> None:
         SSL4EOL(root)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             SSL4EOL(str(tmp_path))
 
     def test_invalid_split(self) -> None:
@@ -155,7 +155,7 @@ def test_invalid_split(self) -> None:
         SSL4EOS12(split="foo")
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             SSL4EOS12(str(tmp_path))
 
     def test_plot(self, dataset: SSL4EOS12) -> None:
diff --git a/tests/datasets/test_ssl4eo_benchmark.py b/tests/datasets/test_ssl4eo_benchmark.py
index 1cc1809f80d..0d5b3f94030 100644
--- a/tests/datasets/test_ssl4eo_benchmark.py
+++ b/tests/datasets/test_ssl4eo_benchmark.py
@@ -16,7 +16,13 @@ from torch.utils.data import ConcatDataset
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import CDL, NLCD, RasterDataset, SSL4EOLBenchmark
+from torchgeo.datasets import (
+    CDL,
+    NLCD,
+    DatasetNotFoundError,
+    RasterDataset,
+    SSL4EOLBenchmark,
+)
 
 
 def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -137,7 +143,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
         SSL4EOLBenchmark(root)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             SSL4EOLBenchmark(str(tmp_path))
 
     def test_plot(self, dataset: SSL4EOLBenchmark) -> None:
diff --git a/tests/datasets/test_sustainbench_crop_yield.py b/tests/datasets/test_sustainbench_crop_yield.py
index 04c056ed505..071f0c81a8f 100644
--- a/tests/datasets/test_sustainbench_crop_yield.py
+++ b/tests/datasets/test_sustainbench_crop_yield.py
@@ -13,7 +13,7 @@ from pytest import MonkeyPatch
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import SustainBenchCropYield
+from torchgeo.datasets import DatasetNotFoundError, SustainBenchCropYield
 
 
 def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -71,7 +71,7 @@ def test_invalid_split(self) -> None:
         SustainBenchCropYield(split="foo")
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found in"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             SustainBenchCropYield(str(tmp_path))
 
     def test_plot(self, dataset: SustainBenchCropYield) -> None:
diff --git a/tests/datasets/test_ucmerced.py b/tests/datasets/test_ucmerced.py
index c4096276725..61c76f9cecd 100644
--- a/tests/datasets/test_ucmerced.py
+++ b/tests/datasets/test_ucmerced.py
@@ -14,7 +14,7 @@ from torch.utils.data import ConcatDataset
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import UCMerced
+from torchgeo.datasets import DatasetNotFoundError, UCMerced
 
 
 def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -81,10 +81,7 @@ def test_already_downloaded_not_extracted(
         UCMerced(root=str(tmp_path), download=False)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        err = "Dataset not found in `root` directory and `download=False`, "
-        "either specify a different `root` directory or use `download=True` "
-        "to automatically download the dataset."
-        with pytest.raises(RuntimeError, match=err):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             UCMerced(str(tmp_path))
 
     def test_plot(self, dataset: UCMerced) -> None:
diff --git a/tests/datasets/test_usavars.py b/tests/datasets/test_usavars.py
index 3754b239300..4c256ad5c25 100644
--- a/tests/datasets/test_usavars.py
+++ b/tests/datasets/test_usavars.py
@@ -14,7 +14,7 @@ from torch.utils.data import ConcatDataset
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import USAVars
+from torchgeo.datasets import DatasetNotFoundError, USAVars
 
 
 def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -129,7 +129,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
         USAVars(root)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             USAVars(str(tmp_path))
 
     def test_plot(self, dataset: USAVars) -> None:
diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py
index 0346d13a2b0..be5950fb071 100644
--- a/tests/datasets/test_utils.py
+++ b/tests/datasets/test_utils.py
@@ -18,10 +18,12 @@ import torch
 from pytest import MonkeyPatch
 from rasterio.crs import CRS
+from torch.utils.data import Dataset
 
 import torchgeo.datasets.utils
 from torchgeo.datasets.utils import (
     BoundingBox,
+    DatasetNotFoundError,
     concat_samples,
     disambiguate_timestamp,
     download_and_extract_archive,
@@ -36,6 +38,52 @@
 )
 
 
+class TestDatasetNotFoundError:
+    def test_none(self) -> None:
+        ds: Dataset[Any] = Dataset()
+        match = "Dataset not found."
+        with pytest.raises(DatasetNotFoundError, match=match):
+            raise DatasetNotFoundError(ds)
+
+    def test_root(self) -> None:
+        ds: Dataset[Any] = Dataset()
+        ds.root = "foo"  # type: ignore[attr-defined]
+        match = "Dataset not found in `root='foo'` and cannot be automatically "
+        match += "downloaded, either specify a different `root` or manually "
+        match += "download the dataset."
+        with pytest.raises(DatasetNotFoundError, match=match):
+            raise DatasetNotFoundError(ds)
+
+    def test_paths(self) -> None:
+        ds: Dataset[Any] = Dataset()
+        ds.paths = "foo"  # type: ignore[attr-defined]
+        match = "Dataset not found in `paths='foo'` and cannot be automatically "
+        match += "downloaded, either specify a different `paths` or manually "
+        match += "download the dataset."
+        with pytest.raises(DatasetNotFoundError, match=match):
+            raise DatasetNotFoundError(ds)
+
+    def test_root_download(self) -> None:
+        ds: Dataset[Any] = Dataset()
+        ds.root = "foo"  # type: ignore[attr-defined]
+        ds.download = False  # type: ignore[attr-defined]
+        match = "Dataset not found in `root='foo'` and `download=False`, either "
+        match += "specify a different `root` or use `download=True` to automatically "
+        match += "download the dataset."
+        with pytest.raises(DatasetNotFoundError, match=match):
+            raise DatasetNotFoundError(ds)
+
+    def test_paths_download(self) -> None:
+        ds: Dataset[Any] = Dataset()
+        ds.paths = "foo"  # type: ignore[attr-defined]
+        ds.download = False  # type: ignore[attr-defined]
+        match = "Dataset not found in `paths='foo'` and `download=False`, either "
+        match += "specify a different `paths` or use `download=True` to automatically "
+        match += "download the dataset."
+        with pytest.raises(DatasetNotFoundError, match=match):
+            raise DatasetNotFoundError(ds)
+
+
 @pytest.fixture
 def mock_missing_module(monkeypatch: MonkeyPatch) -> None:
     import_orig = builtins.__import__
@@ -48,7 +96,7 @@ def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
     monkeypatch.setattr(builtins, "__import__", mocked_import)
 
 
-class Dataset:
+class MLHubDataset:
     def download(self, output_dir: str, **kwargs: str) -> None:
         glob_path = os.path.join(
             "tests", "data", "ref_african_crops_kenya_02", "*.tar.gz"
@@ -66,8 +114,8 @@ def download(self, output_dir: str, **kwargs: str) -> None:
             shutil.copy(tarball, output_dir)
 
 
-def fetch_dataset(dataset_id: str, **kwargs: str) -> Dataset:
-    return Dataset()
+def fetch_dataset(dataset_id: str, **kwargs: str) -> MLHubDataset:
+    return MLHubDataset()
 
 
 def fetch_collection(collection_id: str, **kwargs: str) -> Collection:
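The tests above fully specify the error messages, which are derived from whichever of `root`, `paths`, and `download` the dataset defines. A minimal sketch of an exception class that produces exactly those messages, assuming the shipped implementation may differ in details (the class name here carries a `Sketch` suffix to mark it as illustrative):

    from typing import Any

    class DatasetNotFoundErrorSketch(FileNotFoundError):
        """Raised when a dataset is missing from root/paths."""

        def __init__(self, dataset: Any) -> None:
            msg = "Dataset not found"
            # Prefer `root` over `paths` when both exist.
            var = None
            if hasattr(dataset, "root"):
                var = "root"
            elif hasattr(dataset, "paths"):
                var = "paths"
            if var is None:
                msg += "."
            elif getattr(dataset, "download", None) is False:
                msg += (
                    f" in `{var}={getattr(dataset, var)!r}` and `download=False`,"
                    f" either specify a different `{var}` or use `download=True`"
                    " to automatically download the dataset."
                )
            else:
                msg += (
                    f" in `{var}={getattr(dataset, var)!r}` and cannot be"
                    f" automatically downloaded, either specify a different"
                    f" `{var}` or manually download the dataset."
                )
            super().__init__(msg)

Subclassing FileNotFoundError keeps old `except FileNotFoundError` call sites working while every test in this patch only needs to match the common "Dataset not found" prefix.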
diff --git a/tests/datasets/test_vaihingen.py b/tests/datasets/test_vaihingen.py
index 56240b2aca6..fe34bccea08 100644
--- a/tests/datasets/test_vaihingen.py
+++ b/tests/datasets/test_vaihingen.py
@@ -12,7 +12,7 @@ from _pytest.fixtures import SubRequest
 from pytest import MonkeyPatch
 
-from torchgeo.datasets import Vaihingen2D
+from torchgeo.datasets import DatasetNotFoundError, Vaihingen2D
 
 
 class TestVaihingen2D:
@@ -69,7 +69,7 @@ def test_invalid_split(self) -> None:
         Vaihingen2D(split="foo")
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             Vaihingen2D(str(tmp_path))
 
     def test_plot(self, dataset: Vaihingen2D) -> None:
diff --git a/tests/datasets/test_vhr10.py b/tests/datasets/test_vhr10.py
index ce69db7ef81..805b84a3117 100644
--- a/tests/datasets/test_vhr10.py
+++ b/tests/datasets/test_vhr10.py
@@ -16,7 +16,7 @@ from torch.utils.data import ConcatDataset
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import VHR10
+from torchgeo.datasets import VHR10, DatasetNotFoundError
 
 pytest.importorskip("pycocotools")
 
@@ -90,7 +90,7 @@ def test_invalid_split(self) -> None:
         VHR10(split="train")
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             VHR10(str(tmp_path))
 
     def test_mock_missing_module(
diff --git a/tests/datasets/test_western_usa_live_fuel_moisture.py b/tests/datasets/test_western_usa_live_fuel_moisture.py
index 111da76dde6..3337965228e 100644
--- a/tests/datasets/test_western_usa_live_fuel_moisture.py
+++ b/tests/datasets/test_western_usa_live_fuel_moisture.py
@@ -10,7 +10,7 @@ import torch.nn as nn
 from pytest import MonkeyPatch
 
-from torchgeo.datasets import WesternUSALiveFuelMoisture
+from torchgeo.datasets import DatasetNotFoundError, WesternUSALiveFuelMoisture
 
 
 class Collection:
@@ -65,7 +65,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
         WesternUSALiveFuelMoisture(root)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found in"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             WesternUSALiveFuelMoisture(str(tmp_path))
 
     def test_invalid_features(self, dataset: WesternUSALiveFuelMoisture) -> None:
diff --git a/tests/datasets/test_xview2.py b/tests/datasets/test_xview2.py
index 957de5a36d7..28292775a46 100644
--- a/tests/datasets/test_xview2.py
+++ b/tests/datasets/test_xview2.py
@@ -12,7 +12,7 @@ from _pytest.fixtures import SubRequest
 from pytest import MonkeyPatch
 
-from torchgeo.datasets import XView2
+from torchgeo.datasets import DatasetNotFoundError, XView2
 
 
 class TestXView2:
@@ -80,7 +80,7 @@ def test_invalid_split(self) -> None:
         XView2(split="foo")
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             XView2(str(tmp_path))
 
     def test_plot(self, dataset: XView2) -> None:
diff --git a/tests/datasets/test_zuericrop.py b/tests/datasets/test_zuericrop.py
index 18b5a87eb65..27325672869 100644
--- a/tests/datasets/test_zuericrop.py
+++ b/tests/datasets/test_zuericrop.py
@@ -14,7 +14,7 @@ from pytest import MonkeyPatch
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import ZueriCrop
+from torchgeo.datasets import DatasetNotFoundError, ZueriCrop
 
 pytest.importorskip("h5py", minversion="3")
 
@@ -79,10 +79,7 @@ def test_already_downloaded(self, dataset: ZueriCrop) -> None:
         ZueriCrop(root=dataset.root, download=True)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
-        err = "Dataset not found in `root` directory and `download=False`, "
-        "either specify a different `root` directory or use `download=True` "
-        "to automatically download the dataset."
-        with pytest.raises(RuntimeError, match=err):
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
             ZueriCrop(str(tmp_path))
 
     def test_mock_missing_module(
diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py
index 8761e850e18..66555c7b978 100644
--- a/torchgeo/datamodules/__init__.py
+++ b/torchgeo/datamodules/__init__.py
@@ -18,6 +18,7 @@ from .l7irish import L7IrishDataModule
 from .l8biome import L8BiomeDataModule
 from .landcoverai import LandCoverAIDataModule
+from .levircd import LEVIRCDPlusDataModule
 from .loveda import LoveDADataModule
 from .naip import NAIPChesapeakeDataModule
 from .nasa_marine_debris import NASAMarineDebrisDataModule
@@ -56,6 +57,7 @@
     "GID15DataModule",
     "InriaAerialImageLabelingDataModule",
     "LandCoverAIDataModule",
+    "LEVIRCDPlusDataModule",
     "LoveDADataModule",
     "NASAMarineDebrisDataModule",
     "OSCDDataModule",
diff --git a/torchgeo/datamodules/eurosat.py b/torchgeo/datamodules/eurosat.py
index b4267cfc50a..ccf2a90f691 100644
--- a/torchgeo/datamodules/eurosat.py
+++ b/torchgeo/datamodules/eurosat.py
@@ -10,41 +10,37 @@ from ..datasets import EuroSAT, EuroSAT100
 from .geo import NonGeoDataModule
 
-MEAN = torch.tensor(
-    [
-        1354.40546513,
-        1118.24399958,
-        1042.92983953,
-        947.62620298,
-        1199.47283961,
-        1999.79090914,
-        2369.22292565,
-        2296.82608323,
-        732.08340178,
-        12.11327804,
-        1819.01027855,
-        1118.92391149,
-        2594.14080798,
-    ]
-)
-
-STD = torch.tensor(
-    [
-        245.71762908,
-        333.00778264,
-        395.09249139,
-        593.75055589,
-        566.4170017,
-        861.18399006,
-        1086.63139075,
-        1117.98170791,
-        404.91978886,
-        4.77584468,
-        1002.58768311,
-        761.30323499,
-        1231.58581042,
-    ]
-)
+MEAN = {
+    "B01": 1354.40546513,
+    "B02": 1118.24399958,
+    "B03": 1042.92983953,
+    "B04": 947.62620298,
+    "B05": 1199.47283961,
+    "B06": 1999.79090914,
+    "B07": 2369.22292565,
+    "B08": 2296.82608323,
+    "B8A": 732.08340178,
+    "B09": 12.11327804,
+    "B10": 1819.01027855,
+    "B11": 1118.92391149,
+    "B12": 2594.14080798,
+}
+
+STD = {
+    "B01": 245.71762908,
+    "B02": 333.00778264,
+    "B03": 395.09249139,
+    "B04": 593.75055589,
+    "B05": 566.4170017,
+    "B06": 861.18399006,
+    "B07": 1086.63139075,
+    "B08": 1117.98170791,
+    "B8A": 404.91978886,
+    "B09": 4.77584468,
+    "B10": 1002.58768311,
+    "B11": 761.30323499,
+    "B12": 1231.58581042,
+}
 
 
 class EuroSATDataModule(NonGeoDataModule):
@@ -55,9 +51,6 @@ class EuroSATDataModule(NonGeoDataModule):
     .. versionadded:: 0.2
     """
 
-    mean = MEAN
-    std = STD
-
     def __init__(
         self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
     ) -> None:
@@ -71,6 +64,10 @@ def __init__(
         """
         super().__init__(EuroSAT, batch_size, num_workers, **kwargs)
 
+        bands = kwargs.get("bands", EuroSAT.all_band_names)
+        self.mean = torch.tensor([MEAN[b] for b in bands])
+        self.std = torch.tensor([STD[b] for b in bands])
+
 
 class EuroSAT100DataModule(NonGeoDataModule):
     """LightningDataModule implementation for the EuroSAT100 dataset.
@@ -80,9 +77,6 @@ class EuroSAT100DataModule(NonGeoDataModule):
     .. versionadded:: 0.5
     """
 
-    mean = MEAN
-    std = STD
-
    def __init__(
         self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
     ) -> None:
@@ -95,3 +89,7 @@ def __init__(
             :class:`~torchgeo.datasets.EuroSAT100`.
         """
         super().__init__(EuroSAT100, batch_size, num_workers, **kwargs)
+
+        bands = kwargs.get("bands", EuroSAT.all_band_names)
+        self.mean = torch.tensor([MEAN[b] for b in bands])
+        self.std = torch.tensor([STD[b] for b in bands])
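Keying the statistics by band name means the normalization tensors automatically follow whatever band subset is requested, instead of relying on fixed index positions. A hypothetical usage sketch (the band subset is a placeholder and assumes `bands` is forwarded to the underlying EuroSAT dataset as the diff suggests):

    from torchgeo.datamodules import EuroSATDataModule

    dm = EuroSATDataModule(batch_size=32, bands=("B02", "B03", "B04"))
    assert dm.mean.shape == (3,)  # one statistic per requested band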
+ """ + if stage in ["fit", "validate"]: + self.dataset = LEVIRCDPlus(split="train", **self.kwargs) + self.train_dataset, self.val_dataset = dataset_split( + self.dataset, val_pct=self.val_split_pct + ) + if stage in ["test"]: + self.test_dataset = LEVIRCDPlus(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 748c4038091..19f34677065 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -7,7 +7,6 @@ import kornia.augmentation as K import torch -from einops import repeat from ..datasets import OSCD from ..samplers.utils import _to_tuple @@ -16,6 +15,38 @@ from .geo import NonGeoDataModule from .utils import dataset_split +MEAN = { + "B01": 1583.0741, + "B02": 1374.3202, + "B03": 1294.1616, + "B04": 1325.6158, + "B05": 1478.7408, + "B06": 1933.0822, + "B07": 2166.0608, + "B08": 2076.4868, + "B8A": 2306.0652, + "B09": 690.9814, + "B10": 16.2360, + "B11": 2080.3347, + "B12": 1524.6930, +} + +STD = { + "B01": 52.1937, + "B02": 83.4168, + "B03": 105.6966, + "B04": 151.1401, + "B05": 147.4615, + "B06": 115.9289, + "B07": 123.1974, + "B08": 114.6483, + "B8A": 141.4530, + "B09": 73.2758, + "B10": 4.8368, + "B11": 213.4821, + "B12": 179.4793, +} + class OSCDDataModule(NonGeoDataModule): """LightningDataModule implementation for the OSCD dataset. @@ -26,42 +57,6 @@ class OSCDDataModule(NonGeoDataModule): .. versionadded:: 0.2 """ - mean = torch.tensor( - [ - 1583.0741, - 1374.3202, - 1294.1616, - 1325.6158, - 1478.7408, - 1933.0822, - 2166.0608, - 2076.4868, - 2306.0652, - 690.9814, - 16.2360, - 2080.3347, - 1524.6930, - ] - ) - - std = torch.tensor( - [ - 52.1937, - 83.4168, - 105.6966, - 151.1401, - 147.4615, - 115.9289, - 123.1974, - 114.6483, - 141.4530, - 73.2758, - 4.8368, - 213.4821, - 179.4793, - ] - ) - def __init__( self, batch_size: int = 64, @@ -86,19 +81,14 @@ def __init__( self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.bands = kwargs.get("bands", "all") - if self.bands == "rgb": - self.mean = self.mean[[3, 2, 1]] - self.std = self.std[[3, 2, 1]] - - # Change detection, 2 images from different times - self.mean = repeat(self.mean, "c -> (t c)", t=2) - self.std = repeat(self.std, "c -> (t c)", t=2) + self.bands = kwargs.get("bands", OSCD.all_bands) + self.mean = torch.tensor([MEAN[b] for b in self.bands]) + self.std = torch.tensor([STD[b] for b in self.bands]) self.aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=["image", "mask"], + data_keys=["image1", "image2", "mask"], ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index abe83cdb374..56ed093996a 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -117,6 +117,7 @@ from .usavars import USAVars from .utils import ( BoundingBox, + DatasetNotFoundError, concat_samples, merge_samples, stack_samples, @@ -255,4 +256,6 @@ "random_grid_cell_assignment", "roi_split", "time_series_split", + # Errors + "DatasetNotFoundError", ) diff --git a/torchgeo/datasets/advance.py b/torchgeo/datasets/advance.py index 5d5684d06b3..3618db0fa9b 100644 --- a/torchgeo/datasets/advance.py +++ b/torchgeo/datasets/advance.py @@ -15,7 +15,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import download_and_extract_archive +from .utils import DatasetNotFoundError, download_and_extract_archive class ADVANCE(NonGeoDataset): @@ -101,8 +101,7 @@ def __init__( 
checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + DatasetNotFoundError: If dataset is not found and *download* is False. """ self.root = root self.transforms = transforms @@ -112,10 +111,7 @@ def __init__( self._download() if not self._check_integrity(): - raise RuntimeError( - "Dataset not found or corrupted. " - + "You can use download=True to download it" - ) + raise DatasetNotFoundError(self) self.files = self._load_files(self.root) self.classes = sorted({f["cls"] for f in self.files}) @@ -218,11 +214,7 @@ def _check_integrity(self) -> bool: return True def _download(self) -> None: - """Download the dataset and extract it. - - Raises: - AssertionError: if the checksum of split.py does not match - """ + """Download the dataset and extract it.""" if self._check_integrity(): print("Files already downloaded and verified") return diff --git a/torchgeo/datasets/agb_live_woody_density.py b/torchgeo/datasets/agb_live_woody_density.py index ec9ae90bd67..6c5959a0660 100644 --- a/torchgeo/datasets/agb_live_woody_density.py +++ b/torchgeo/datasets/agb_live_woody_density.py @@ -13,7 +13,7 @@ from rasterio.crs import CRS from .geo import RasterDataset -from .utils import download_url +from .utils import DatasetNotFoundError, download_url class AbovegroundLiveWoodyBiomassDensity(RasterDataset): @@ -44,10 +44,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset): is_image = False - url = ( - "https://opendata.arcgis.com/api/v3/datasets/3e8736c8866b458687" - "e00d40c9f00bce_0/downloads/data?format=geojson&spatialRefId=4326" - ) + url = "https://opendata.arcgis.com/api/v3/datasets/e4bdbe8d6d8d4e32ace7d36a4aec7b93_0/downloads/data?format=geojson&spatialRefId=4326" # noqa: E501 base_filename = "Aboveground_Live_Woody_Biomass_Density.geojson" @@ -80,7 +77,7 @@ def __init__( cache: if True, cache file handle to speed up repeated sampling Raises: - FileNotFoundError: if no files are found in ``paths`` + DatasetNotFoundError: If dataset is not found and *download* is False. .. versionchanged:: 0.5 *root* was renamed to *paths*. @@ -93,22 +90,14 @@ def __init__( super().__init__(paths, crs, res, transforms=transforms, cache=cache) def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if dataset is missing - """ + """Verify the integrity of the dataset.""" # Check if the extracted files already exist if self.files: return # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.paths}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." 
- ) + raise DatasetNotFoundError(self) # Download the dataset self._download() @@ -123,7 +112,7 @@ def _download(self) -> None: for item in content["features"]: download_url( - item["properties"]["download"], + item["properties"]["Mg_px_1_download"], self.paths, item["properties"]["tile_id"] + ".tif", ) diff --git a/torchgeo/datasets/astergdem.py b/torchgeo/datasets/astergdem.py index 305c6bf873c..99215f774b4 100644 --- a/torchgeo/datasets/astergdem.py +++ b/torchgeo/datasets/astergdem.py @@ -10,6 +10,7 @@ from rasterio.crs import CRS from .geo import RasterDataset +from .utils import DatasetNotFoundError class AsterGDEM(RasterDataset): @@ -65,8 +66,7 @@ def __init__( cache: if True, cache file handle to speed up repeated sampling Raises: - FileNotFoundError: if no files are found in ``paths`` - RuntimeError: if dataset is missing + DatasetNotFoundError: If dataset is not found. .. versionchanged:: 0.5 *root* was renamed to *paths*. @@ -78,20 +78,12 @@ def __init__( super().__init__(paths, crs, res, transforms=transforms, cache=cache) def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if dataset is missing - """ + """Verify the integrity of the dataset.""" # Check if the extracted files already exists if self.files: return - raise RuntimeError( - f"Dataset not found in `root={self.paths}` " - "either specify a different `root` directory or make sure you " - "have manually downloaded dataset tiles as suggested in the documentation." - ) + raise DatasetNotFoundError(self) def plot( self, diff --git a/torchgeo/datasets/benin_cashews.py b/torchgeo/datasets/benin_cashews.py index 4066fc6af2b..dccf29cedf0 100644 --- a/torchgeo/datasets/benin_cashews.py +++ b/torchgeo/datasets/benin_cashews.py @@ -18,7 +18,12 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import ( + DatasetNotFoundError, + check_integrity, + download_radiant_mlhub_collection, + extract_archive, +) # TODO: read geospatial information from stac.json files @@ -198,7 +203,7 @@ def __init__( verbose: if True, print messages when new tiles are loaded Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails + DatasetNotFoundError: If dataset is not found and *download* is False. """ self._validate_bands(bands) @@ -214,10 +219,7 @@ def __init__( self._download(api_key) if not self._check_integrity(): - raise RuntimeError( - "Dataset not found or corrupted. " - + "You can use download=True to download it" - ) + raise DatasetNotFoundError(self) # Calculate the indices that we will use over all tiles self.chips_metadata = [] diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py index 0194fbbe106..9a127248a8b 100644 --- a/torchgeo/datasets/bigearthnet.py +++ b/torchgeo/datasets/bigearthnet.py @@ -17,7 +17,12 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import download_url, extract_archive, sort_sentinel2_bands +from .utils import ( + DatasetNotFoundError, + download_url, + extract_archive, + sort_sentinel2_bands, +) class BigEarthNet(NonGeoDataset): @@ -285,6 +290,9 @@ def __init__( entry and returns a transformed version download: if True, download dataset and store it in the root directory checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + DatasetNotFoundError: If dataset is not found and *download* is False. 
""" assert split in self.splits_metadata assert bands in ["s1", "s2", "all"] @@ -434,11 +442,7 @@ def _load_target(self, index: int) -> Tensor: return target def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" keys = ["s1", "s2"] if self.bands == "all" else [self.bands] urls = [self.metadata[k]["url"] for k in keys] md5s = [self.metadata[k]["md5"] for k in keys] @@ -478,11 +482,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - "Dataset not found in `root` directory and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download and extract the dataset for url, filename, md5 in zip(urls, filenames, md5s): diff --git a/torchgeo/datasets/biomassters.py b/torchgeo/datasets/biomassters.py index 1288710bd23..970c5594950 100644 --- a/torchgeo/datasets/biomassters.py +++ b/torchgeo/datasets/biomassters.py @@ -16,7 +16,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import percentile_normalization +from .utils import DatasetNotFoundError, percentile_normalization class BioMassters(NonGeoDataset): @@ -75,8 +75,9 @@ def __init__( as_time_series: whether or not to return all available time-steps or just a single one for a given target location - RuntimeError: + Raises: AssertionError: if ``split`` or ``sensors`` is invalid + DatasetNotFoundError: If dataset is not found. """ self.root = root @@ -212,7 +213,7 @@ def _verify(self) -> None: if all(exists): return - raise RuntimeError(f"Dataset not found in `root={self.root}`.") + raise DatasetNotFoundError(self) def plot( self, diff --git a/torchgeo/datasets/cbf.py b/torchgeo/datasets/cbf.py index f2fe2b064f6..d3010e44dce 100644 --- a/torchgeo/datasets/cbf.py +++ b/torchgeo/datasets/cbf.py @@ -12,7 +12,7 @@ from rasterio.crs import CRS from .geo import VectorDataset -from .utils import check_integrity, download_and_extract_archive +from .utils import DatasetNotFoundError, check_integrity, download_and_extract_archive class CanadianBuildingFootprints(VectorDataset): @@ -81,9 +81,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - FileNotFoundError: if no files are found in ``root`` - RuntimeError: if ``download=False`` and data is not found, or - ``checksum=True`` and checksums don't match + DatasetNotFoundError: If dataset is not found and *download* is False. .. versionchanged:: 0.5 *root* was renamed to *paths*. @@ -95,10 +93,7 @@ def __init__( self._download() if not self._check_integrity(): - raise RuntimeError( - "Dataset not found or corrupted. 
" - + "You can use download=True to download it" - ) + raise DatasetNotFoundError(self) super().__init__(paths, crs, res, transforms) diff --git a/torchgeo/datasets/cdl.py b/torchgeo/datasets/cdl.py index 39a7a105325..bad87db572a 100644 --- a/torchgeo/datasets/cdl.py +++ b/torchgeo/datasets/cdl.py @@ -13,7 +13,7 @@ from rasterio.crs import CRS from .geo import RasterDataset -from .utils import BoundingBox, download_url, extract_archive +from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive class CDL(RasterDataset): @@ -234,8 +234,7 @@ def __init__( Raises: AssertionError: if ``years`` or ``classes`` are invalid - FileNotFoundError: if no files are found in ``paths`` - RuntimeError: if ``download=False`` but dataset is missing or checksum fails + DatasetNotFoundError: If dataset is not found and *download* is False. .. versionadded:: 0.5 The *years* and *classes* parameters. @@ -286,11 +285,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: return sample def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the extracted files already exist if self.files: return @@ -313,11 +308,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.paths}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download the dataset self._download() diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index 5dd3acb1b97..17b0ee8e74a 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -24,7 +24,7 @@ from torch import Tensor from .geo import GeoDataset, RasterDataset -from .utils import BoundingBox, download_url, extract_archive +from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive class Chesapeake(RasterDataset, abc.ABC): @@ -112,8 +112,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - FileNotFoundError: if no files are found in ``paths`` - RuntimeError: if ``download=False`` but dataset is missing or checksum fails + DatasetNotFoundError: If dataset is not found and *download* is False. .. versionchanged:: 0.5 *root* was renamed to *paths*. @@ -138,11 +137,7 @@ def __init__( super().__init__(paths, crs, res, transforms=transforms, cache=cache) def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the extracted file already exists if self.files: return @@ -155,11 +150,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.paths}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." 
- ) + raise DatasetNotFoundError(self) # Download the dataset self._download() @@ -562,9 +553,8 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - FileNotFoundError: if no files are found in ``root`` - RuntimeError: if ``download=False`` but dataset is missing or checksum fails AssertionError: if ``splits`` or ``layers`` are not valid + DatasetNotFoundError: If dataset is not found and *download* is False. """ for split in splits: assert split in self.splits @@ -694,11 +684,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: return sample def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" def exists(filename: str) -> bool: return os.path.exists(os.path.join(self.root, filename)) @@ -719,11 +705,7 @@ def exists(filename: str) -> bool: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download the dataset self._download() diff --git a/torchgeo/datasets/cloud_cover.py b/torchgeo/datasets/cloud_cover.py index 94541ce4c0e..63a6433f630 100644 --- a/torchgeo/datasets/cloud_cover.py +++ b/torchgeo/datasets/cloud_cover.py @@ -16,7 +16,12 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import ( + DatasetNotFoundError, + check_integrity, + download_radiant_mlhub_collection, + extract_archive, +) # TODO: read geospatial information from stac.json files @@ -123,7 +128,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails + DatasetNotFoundError: If dataset is not found and *download* is False. """ self.root = root self.split = split @@ -137,10 +142,7 @@ def __init__( self._download(api_key) if not self._check_integrity(): - raise RuntimeError( - "Dataset not found or corrupted. 
" - + "You can use download=True to download it" - ) + raise DatasetNotFoundError(self) self.chip_paths = self._load_collections() @@ -331,9 +333,6 @@ def _download(self, api_key: Optional[str] = None) -> None: Args: api_key: a RadiantEarth MLHub API key to use for downloading the dataset - - Raises: - RuntimeError: if download doesn't work correctly or checksums don't match """ if self._check_integrity(): print("Files already downloaded and verified") diff --git a/torchgeo/datasets/cms_mangrove_canopy.py b/torchgeo/datasets/cms_mangrove_canopy.py index df02764705f..ac42c8d1ead 100644 --- a/torchgeo/datasets/cms_mangrove_canopy.py +++ b/torchgeo/datasets/cms_mangrove_canopy.py @@ -11,7 +11,7 @@ from rasterio.crs import CRS from .geo import RasterDataset -from .utils import check_integrity, extract_archive +from .utils import DatasetNotFoundError, check_integrity, extract_archive class CMSGlobalMangroveCanopy(RasterDataset): @@ -192,9 +192,8 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - FileNotFoundError: if no files are found in ``paths`` - RuntimeError: if dataset is missing or checksum fails AssertionError: if country or measurement arg are not str or invalid + DatasetNotFoundError: If dataset is not found. .. versionchanged:: 0.5 *root* was renamed to *paths*. @@ -225,11 +224,7 @@ def __init__( super().__init__(paths, crs, res, transforms=transforms, cache=cache) def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the extracted files already exist if self.files: return @@ -243,11 +238,7 @@ def _verify(self) -> None: self._extract() return - raise RuntimeError( - f"Dataset not found in `root={self.paths}` " - "either specify a different `root` directory or make sure you " - "have manually downloaded the dataset as instructed in the documentation." - ) + raise DatasetNotFoundError(self) def _extract(self) -> None: """Extract the dataset.""" diff --git a/torchgeo/datasets/cowc.py b/torchgeo/datasets/cowc.py index 9c60e01c302..0e0518502ad 100644 --- a/torchgeo/datasets/cowc.py +++ b/torchgeo/datasets/cowc.py @@ -16,7 +16,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, download_and_extract_archive +from .utils import DatasetNotFoundError, check_integrity, download_and_extract_archive class COWC(NonGeoDataset, abc.ABC): @@ -81,8 +81,7 @@ def __init__( Raises: AssertionError: if ``split`` argument is invalid - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in ["train", "test"] @@ -95,10 +94,7 @@ def __init__( self._download() if not self._check_integrity(): - raise RuntimeError( - "Dataset not found or corrupted. 
" - + "You can use download=True to download it" - ) + raise DatasetNotFoundError(self) self.images = [] self.targets = [] diff --git a/torchgeo/datasets/cv4a_kenya_crop_type.py b/torchgeo/datasets/cv4a_kenya_crop_type.py index aa1caba5134..1dc6a2c36a7 100644 --- a/torchgeo/datasets/cv4a_kenya_crop_type.py +++ b/torchgeo/datasets/cv4a_kenya_crop_type.py @@ -16,7 +16,12 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import ( + DatasetNotFoundError, + check_integrity, + download_radiant_mlhub_collection, + extract_archive, +) # TODO: read geospatial information from stac.json files @@ -141,7 +146,7 @@ def __init__( verbose: if True, print messages when new tiles are loaded Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails + DatasetNotFoundError: If dataset is not found and *download* is False. """ self._validate_bands(bands) @@ -157,10 +162,7 @@ def __init__( self._download(api_key) if not self._check_integrity(): - raise RuntimeError( - "Dataset not found or corrupted. " - + "You can use download=True to download it" - ) + raise DatasetNotFoundError(self) # Calculate the indices that we will use over all tiles self.chips_metadata = [] @@ -390,9 +392,6 @@ def _download(self, api_key: Optional[str] = None) -> None: Args: api_key: a RadiantEarth MLHub API key to use for downloading the dataset - - Raises: - RuntimeError: if download doesn't work correctly or checksums don't match """ if self._check_integrity(): print("Files already downloaded and verified") diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index 040754a9633..b0cc6f8f1c4 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -16,7 +16,12 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import ( + DatasetNotFoundError, + check_integrity, + download_radiant_mlhub_collection, + extract_archive, +) class TropicalCyclone(NonGeoDataset): @@ -86,7 +91,7 @@ def __init__( Raises: AssertionError: if ``split`` argument is invalid - RuntimeError: if ``download=False`` but dataset is missing or checksum fails + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in self.md5s @@ -99,10 +104,7 @@ def __init__( self._download(api_key) if not self._check_integrity(): - raise RuntimeError( - "Dataset not found or corrupted. 
" - + "You can use download=True to download it" - ) + raise DatasetNotFoundError(self) output_dir = "_".join([self.collection_id, split, "source"]) filename = os.path.join(root, output_dir, "collection.json") @@ -206,9 +208,6 @@ def _download(self, api_key: Optional[str] = None) -> None: Args: api_key: a RadiantEarth MLHub API key to use for downloading the dataset - - Raises: - RuntimeError: if download doesn't work correctly or checksums don't match """ if self._check_integrity(): print("Files already downloaded and verified") diff --git a/torchgeo/datasets/deepglobelandcover.py b/torchgeo/datasets/deepglobelandcover.py index 694da07f3d1..233c70cd049 100644 --- a/torchgeo/datasets/deepglobelandcover.py +++ b/torchgeo/datasets/deepglobelandcover.py @@ -15,6 +15,7 @@ from .geo import NonGeoDataset from .utils import ( + DatasetNotFoundError, check_integrity, draw_semantic_segmentation_masks, extract_archive, @@ -102,6 +103,9 @@ def __init__( transforms: a function/transform that takes input sample and its target as entry and returns a transformed version checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + DatasetNotFoundError: If dataset is not found. """ assert split in self.splits self.root = root @@ -195,11 +199,7 @@ def _load_target(self, index: int) -> Tensor: return tensor def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if checksum fails or the dataset is not downloaded - """ + """Verify the integrity of the dataset.""" # Check if the files already exist if os.path.exists(os.path.join(self.root, self.data_root)): return @@ -213,11 +213,7 @@ def _verify(self) -> None: extract_archive(filepath) return - # Check if the user requested to download the dataset - raise RuntimeError( - "Dataset not found in `root`, either specify a different" - + " `root` directory or manually download the dataset to this directory." - ) + raise DatasetNotFoundError(self) def plot( self, diff --git a/torchgeo/datasets/dfc2022.py b/torchgeo/datasets/dfc2022.py index fd76d2710cb..5268ea4f0e5 100644 --- a/torchgeo/datasets/dfc2022.py +++ b/torchgeo/datasets/dfc2022.py @@ -18,7 +18,12 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, extract_archive, percentile_normalization +from .utils import ( + DatasetNotFoundError, + check_integrity, + extract_archive, + percentile_normalization, +) class DFC2022(NonGeoDataset): @@ -153,6 +158,7 @@ def __init__( Raises: AssertionError: if ``split`` is invalid + DatasetNotFoundError: If dataset is not found. """ assert split in self.metadata self.root = root @@ -258,11 +264,7 @@ def _load_target(self, path: str) -> Tensor: return tensor def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if checksum fails or the dataset is not downloaded - """ + """Verify the integrity of the dataset.""" # Check if the files already exist exists = [] for split_info in self.metadata.values(): @@ -288,11 +290,7 @@ def _verify(self) -> None: if all(exists): return - # Check if the user requested to download the dataset - raise RuntimeError( - "Dataset not found in `root` directory, either specify a different" - + " `root` directory or manually download the dataset to this directory." 
- ) + raise DatasetNotFoundError(self) def plot( self, diff --git a/torchgeo/datasets/eddmaps.py b/torchgeo/datasets/eddmaps.py index 2eeee0aeaa4..94e409c4d07 100644 --- a/torchgeo/datasets/eddmaps.py +++ b/torchgeo/datasets/eddmaps.py @@ -12,7 +12,7 @@ from rasterio.crs import CRS from .geo import GeoDataset -from .utils import BoundingBox, disambiguate_timestamp +from .utils import BoundingBox, DatasetNotFoundError, disambiguate_timestamp class EDDMapS(GeoDataset): @@ -48,7 +48,7 @@ def __init__(self, root: str = "data") -> None: root: root directory where dataset can be found Raises: - FileNotFoundError: if no files are found in ``root`` + DatasetNotFoundError: If dataset is not found. """ super().__init__() @@ -56,7 +56,7 @@ def __init__(self, root: str = "data") -> None: filepath = os.path.join(root, "mappings.csv") if not os.path.exists(filepath): - raise FileNotFoundError(f"Dataset not found in `root={self.root}`") + raise DatasetNotFoundError(self) # Read CSV file data = pd.read_csv( diff --git a/torchgeo/datasets/enviroatlas.py b/torchgeo/datasets/enviroatlas.py index 551c142eaf4..f6842738abb 100644 --- a/torchgeo/datasets/enviroatlas.py +++ b/torchgeo/datasets/enviroatlas.py @@ -22,7 +22,7 @@ from rasterio.crs import CRS from .geo import GeoDataset -from .utils import BoundingBox, download_url, extract_archive +from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive class EnviroAtlas(GeoDataset): @@ -278,9 +278,8 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - FileNotFoundError: if no files are found in ``root`` - RuntimeError: if ``download=False`` but dataset is missing or checksum fails AssertionError: if ``splits`` or ``layers`` are not valid + DatasetNotFoundError: If dataset is not found and *download* is False. """ for split in splits: assert split in self.splits @@ -412,11 +411,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: return sample def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" def exists(filename: str) -> bool: return os.path.exists(os.path.join(self.root, "enviroatlas_lotp", filename)) @@ -432,11 +427,7 @@ def exists(filename: str) -> bool: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download the dataset self._download() diff --git a/torchgeo/datasets/esri2020.py b/torchgeo/datasets/esri2020.py index 6b875b8e040..2d26f24565a 100644 --- a/torchgeo/datasets/esri2020.py +++ b/torchgeo/datasets/esri2020.py @@ -13,7 +13,7 @@ from rasterio.crs import CRS from .geo import RasterDataset -from .utils import download_url, extract_archive +from .utils import DatasetNotFoundError, download_url, extract_archive class Esri2020(RasterDataset): @@ -91,8 +91,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - FileNotFoundError: if no files are found in ``paths`` - RuntimeError: if ``download=False`` but dataset is missing or checksum fails + DatasetNotFoundError: If dataset is not found and *download* is False. .. versionchanged:: 0.5 *root* was renamed to *paths*. 
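[Note: the hunks above and below repeat one mechanical migration: ad-hoc RuntimeError and FileNotFoundError messages are replaced by the shared DatasetNotFoundError now exported from torchgeo.datasets. A minimal sketch of what this buys callers, assuming a build that includes these patches; the root path here is illustrative:

    from torchgeo.datasets import DatasetNotFoundError, EuroSAT

    try:
        ds = EuroSAT(root="data/eurosat", download=False)
    except DatasetNotFoundError as err:
        # One structured exception replaces string-matched RuntimeErrors;
        # its message keeps the stable "Dataset not found" prefix that the
        # updated test_xview2.py and test_zuericrop.py hunks match on.
        print(err)

Because the exception is constructed from the dataset instance itself, every dataset reports its missing root/paths consistently instead of each class wording its own error string.]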
@@ -106,11 +105,7 @@ def __init__( super().__init__(paths, crs, res, transforms=transforms, cache=cache) def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the extracted file already exists if self.files: return @@ -124,11 +119,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.paths}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download the dataset self._download() diff --git a/torchgeo/datasets/etci2021.py b/torchgeo/datasets/etci2021.py index e837f1c2d85..7dfa50fb2ab 100644 --- a/torchgeo/datasets/etci2021.py +++ b/torchgeo/datasets/etci2021.py @@ -15,7 +15,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import download_and_extract_archive +from .utils import DatasetNotFoundError, download_and_extract_archive class ETCI2021(NonGeoDataset): @@ -98,8 +98,7 @@ def __init__( Raises: AssertionError: if ``split`` argument is invalid - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in self.metadata.keys() @@ -112,10 +111,7 @@ def __init__( self._download() if not self._check_integrity(): - raise RuntimeError( - "Dataset not found or corrupted. " - + "You can use download=True to download it" - ) + raise DatasetNotFoundError(self) self.files = self._load_files(self.root, self.split) @@ -243,11 +239,7 @@ def _check_integrity(self) -> bool: return True def _download(self) -> None: - """Download the dataset and extract it. - - Raises: - AssertionError: if the checksum of split.py does not match - """ + """Download the dataset and extract it.""" if self._check_integrity(): print("Files already downloaded and verified") return diff --git a/torchgeo/datasets/eudem.py b/torchgeo/datasets/eudem.py index 35313dae075..ce82500ee7d 100644 --- a/torchgeo/datasets/eudem.py +++ b/torchgeo/datasets/eudem.py @@ -13,7 +13,7 @@ from rasterio.crs import CRS from .geo import RasterDataset -from .utils import check_integrity, extract_archive +from .utils import DatasetNotFoundError, check_integrity, extract_archive class EUDEM(RasterDataset): @@ -105,7 +105,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - FileNotFoundError: if no files are found in ``paths`` + DatasetNotFoundError: If dataset is not found. .. versionchanged:: 0.5 *root* was renamed to *paths*. @@ -118,11 +118,7 @@ def __init__( super().__init__(paths, crs, res, transforms=transforms, cache=cache) def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the extracted file already exists if self.files: return @@ -138,11 +134,7 @@ def _verify(self) -> None: extract_archive(zipfile) return - raise RuntimeError( - f"Dataset not found in `root={self.paths}` " - "either specify a different `root` directory or make sure you " - "have manually downloaded the dataset as suggested in the documentation." 
- ) + raise DatasetNotFoundError(self) def plot( self, diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py index 3fd5c05ff92..319f10e1c65 100644 --- a/torchgeo/datasets/eurosat.py +++ b/torchgeo/datasets/eurosat.py @@ -14,7 +14,13 @@ from torch import Tensor from .geo import NonGeoClassificationDataset -from .utils import check_integrity, download_url, extract_archive, rasterio_loader +from .utils import ( + DatasetNotFoundError, + check_integrity, + download_url, + extract_archive, + rasterio_loader, +) class EuroSAT(NonGeoClassificationDataset): @@ -31,16 +37,16 @@ class EuroSAT(NonGeoClassificationDataset): Dataset classes: - * Industrial Buildings - * Residential Buildings * Annual Crop - * Permanent Crop - * River - * Sea and Lake + * Forest * Herbaceous Vegetation * Highway + * Industrial Buildings * Pasture - * Forest + * Permanent Crop + * Residential Buildings + * River + * SeaLake This dataset uses the train/val/test splits defined in the "In-domain representation learning for remote sensing" paper: @@ -73,18 +79,6 @@ class EuroSAT(NonGeoClassificationDataset): "val": "95de90f2aa998f70a3b2416bfe0687b4", "test": "7ae5ab94471417b6e315763121e67c5f", } - classes = [ - "Industrial Buildings", - "Residential Buildings", - "Annual Crop", - "Permanent Crop", - "River", - "Sea and Lake", - "Herbaceous Vegetation", - "Highway", - "Pasture", - "Forest", - ] all_band_names = ( "B01", @@ -128,8 +122,7 @@ def __init__( Raises: AssertionError: if ``split`` argument is invalid - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + DatasetNotFoundError: If dataset is not found and *download* is False. .. versionadded:: 0.3 The *bands* parameter. @@ -192,11 +185,7 @@ def _check_integrity(self) -> bool: return integrity def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the files already exist filepath = os.path.join(self.root, self.base_dir) if os.path.exists(filepath): @@ -209,11 +198,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - "Dataset not found in `root` directory and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download and extract the dataset self._download() diff --git a/torchgeo/datasets/fair1m.py b/torchgeo/datasets/fair1m.py index 7f1a9a20d74..24cc7b97d2f 100644 --- a/torchgeo/datasets/fair1m.py +++ b/torchgeo/datasets/fair1m.py @@ -17,7 +17,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, download_url, extract_archive +from .utils import DatasetNotFoundError, check_integrity, download_url, extract_archive def parse_pascal_voc(path: str) -> dict[str, Any]: @@ -244,8 +244,7 @@ def __init__( Raises: AssertionError: if ``split`` argument is invalid - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + DatasetNotFoundError: If dataset is not found. .. versionchanged:: 0.5 Added *split* and *download* parameters. @@ -329,11 +328,7 @@ def _load_target( return boxes, labels_tensor def _verify(self) -> None: - """Verify the integrity of the dataset. 
- - Raises: - RuntimeError: if checksum fails or the dataset is not found - """ + """Verify the integrity of the dataset.""" # Check if the directories already exist exists = [] for directory in self.directories[self.split]: @@ -362,18 +357,10 @@ def _verify(self) -> None: self._download() return - raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) def _download(self) -> None: - """Download the dataset and extract it. - - Raises: - RuntimeError: if download doesn't work correctly or checksums don't match - """ + """Download the dataset and extract it.""" paths = self.paths[self.split] urls = self.urls[self.split] md5s = self.md5s[self.split] diff --git a/torchgeo/datasets/fire_risk.py b/torchgeo/datasets/fire_risk.py index 76accf24acd..d51a4384f81 100644 --- a/torchgeo/datasets/fire_risk.py +++ b/torchgeo/datasets/fire_risk.py @@ -11,7 +11,7 @@ from torch import Tensor from .geo import NonGeoClassificationDataset -from .utils import download_url, extract_archive +from .utils import DatasetNotFoundError, download_url, extract_archive class FireRisk(NonGeoClassificationDataset): @@ -84,7 +84,7 @@ def __init__( Raises: AssertionError: if ``split`` argument is invalid - RuntimeError: if ``download=False`` but dataset is missing or checksum fails + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in self.splits self.root = root @@ -98,11 +98,7 @@ def __init__( ) def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the files already exist path = os.path.join(self.root, self.directory) if os.path.exists(path): @@ -116,11 +112,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download and extract the dataset self._download() diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py index 7628256e004..b74f3bcd3bc 100644 --- a/torchgeo/datasets/forestdamage.py +++ b/torchgeo/datasets/forestdamage.py @@ -17,7 +17,12 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, download_and_extract_archive, extract_archive +from .utils import ( + DatasetNotFoundError, + check_integrity, + download_and_extract_archive, + extract_archive, +) def parse_pascal_voc(path: str) -> dict[str, Any]: @@ -119,8 +124,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + DatasetNotFoundError: If dataset is not found and *download* is False. """ self.root = root self.transforms = transforms @@ -237,21 +241,13 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - "Dataset not found in `root` directory, either specify a different" - + " `root` directory or manually download " - + "the dataset to this directory." 
- ) + raise DatasetNotFoundError(self) # else download the dataset self._download() def _download(self) -> None: - """Download the dataset and extract it. - - Raises: - AssertionError: if the checksum does not match - """ + """Download the dataset and extract it.""" download_and_extract_archive( self.url, self.root, diff --git a/torchgeo/datasets/gbif.py b/torchgeo/datasets/gbif.py index 8f14720c8ea..a34cc8ba685 100644 --- a/torchgeo/datasets/gbif.py +++ b/torchgeo/datasets/gbif.py @@ -14,7 +14,7 @@ from rasterio.crs import CRS from .geo import GeoDataset -from .utils import BoundingBox +from .utils import BoundingBox, DatasetNotFoundError def _disambiguate_timestamps( @@ -86,7 +86,7 @@ def __init__(self, root: str = "data") -> None: root: root directory where dataset can be found Raises: - FileNotFoundError: if no files are found in ``root`` + DatasetNotFoundError: If dataset is not found. """ super().__init__() @@ -94,7 +94,7 @@ def __init__(self, root: str = "data") -> None: files = glob.glob(os.path.join(root, "**.csv")) if not files: - raise FileNotFoundError(f"Dataset not found in `root={self.root}`") + raise DatasetNotFoundError(self) # Read tab-delimited CSV file data = pd.read_table( diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 7c5c68ac654..34db093e08c 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -9,6 +9,7 @@ import os import re import sys +import warnings from collections.abc import Iterable, Sequence from typing import Any, Callable, Optional, Union, cast @@ -29,7 +30,14 @@ from torchvision.datasets import ImageFolder from torchvision.datasets.folder import default_loader as pil_loader -from .utils import BoundingBox, concat_samples, disambiguate_timestamp, merge_samples +from .utils import ( + BoundingBox, + DatasetNotFoundError, + concat_samples, + disambiguate_timestamp, + merge_samples, + path_is_vsi, +) class GeoDataset(Dataset[dict[str, Any]], abc.ABC): @@ -298,8 +306,14 @@ def files(self) -> set[str]: if os.path.isdir(path): pathname = os.path.join(path, "**", self.filename_glob) files |= set(glob.iglob(pathname, recursive=True)) - else: + elif os.path.isfile(path) or path_is_vsi(path): files.add(path) + else: + warnings.warn( + f"Could not find any relevant files for provided path '{path}'. " + f"Path was ignored.", + UserWarning, + ) return files @@ -377,7 +391,7 @@ def __init__( cache: if True, cache file handle to speed up repeated sampling Raises: - FileNotFoundError: if no files are found in ``paths`` + DatasetNotFoundError: If dataset is not found. .. versionchanged:: 0.5 *root* was renamed to *paths*. @@ -425,13 +439,7 @@ def __init__( i += 1 if i == 0: - msg = ( - f"No {self.__class__.__name__} data was found " - f"in `paths={self.paths!r}'`" - ) - if self.bands: - msg += f" with `bands={self.bands}`" - raise FileNotFoundError(msg) + raise DatasetNotFoundError(self) if not self.separate_files: self.band_indexes = None @@ -593,7 +601,7 @@ def __init__( rasterized into the mask Raises: - FileNotFoundError: if no files are found in ``root`` + DatasetNotFoundError: If dataset is not found. .. versionadded:: 0.4 The *label_name* parameter. 
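[Note: the geo.py hunk above also reworks how GeoDataset.files resolves user-supplied paths: directories are still globbed recursively, plain files and GDAL virtual filesystem paths are now kept via the new path_is_vsi helper, and anything unresolvable emits a UserWarning instead of disappearing silently, with DatasetNotFoundError raised only when nothing matches at all. A hedged sketch of the resulting behavior; the dataset class and file names are illustrative, not part of this patch:

    import warnings
    from torchgeo.datasets import Landsat8

    paths = [
        "data/landsat8/",             # directory: searched recursively
        "data/extra/LC08_B4.TIF",     # plain file: added as-is
        "/vsizip/scenes.zip/B4.TIF",  # VSI path: accepted via path_is_vsi
        "data/typo/",                 # matches nothing: UserWarning, ignored
    ]
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        ds = Landsat8(paths=paths)    # would warn about "data/typo/"
]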
@@ -629,8 +637,7 @@ def __init__( i += 1 if i == 0: - msg = f"No {self.__class__.__name__} data was found in `root='{paths}'`" - raise FileNotFoundError(msg) + raise DatasetNotFoundError(self) self._crs = crs self._res = res diff --git a/torchgeo/datasets/gid15.py b/torchgeo/datasets/gid15.py index cb153c0e3c6..6fc2b520181 100644 --- a/torchgeo/datasets/gid15.py +++ b/torchgeo/datasets/gid15.py @@ -15,7 +15,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import download_and_extract_archive +from .utils import DatasetNotFoundError, download_and_extract_archive class GID15(NonGeoDataset): @@ -105,8 +105,7 @@ def __init__( Raises: AssertionError: if ``split`` argument is invalid - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in self.splits @@ -119,10 +118,7 @@ def __init__( self._download() if not self._check_integrity(): - raise RuntimeError( - "Dataset not found or corrupted. " - + "You can use download=True to download it" - ) + raise DatasetNotFoundError(self) self.files = self._load_files(self.root, self.split) @@ -226,11 +222,7 @@ def _check_integrity(self) -> bool: return True def _download(self) -> None: - """Download the dataset and extract it. - - Raises: - AssertionError: if the checksum of split.py does not match - """ + """Download the dataset and extract it.""" if self._check_integrity(): print("Files already downloaded and verified") return diff --git a/torchgeo/datasets/globbiomass.py b/torchgeo/datasets/globbiomass.py index 7fe6428c8ec..c9da83b9bec 100644 --- a/torchgeo/datasets/globbiomass.py +++ b/torchgeo/datasets/globbiomass.py @@ -14,7 +14,7 @@ from rasterio.crs import CRS from .geo import RasterDataset -from .utils import BoundingBox, check_integrity, extract_archive +from .utils import BoundingBox, DatasetNotFoundError, check_integrity, extract_archive class GlobBiomass(RasterDataset): @@ -142,9 +142,8 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - FileNotFoundError: if no files are found in ``paths`` - RuntimeError: if dataset is missing or checksum fails AssertionError: if measurement argument is invalid, or not a str + DatasetNotFoundError: If dataset is not found. .. versionchanged:: 0.5 *root* was renamed to *paths*. @@ -204,11 +203,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: return sample def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the extracted file already exists if self.files: return @@ -224,11 +219,7 @@ def _verify(self) -> None: extract_archive(zipfile) return - raise RuntimeError( - f"Dataset not found in `root={self.paths}` " - "either specify a different `root` directory or make sure you " - "have manually downloaded the dataset as suggested in the documentation." 
- ) + raise DatasetNotFoundError(self) def plot( self, diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py index 2d3833cf1b3..916ad1af71f 100644 --- a/torchgeo/datasets/idtrees.py +++ b/torchgeo/datasets/idtrees.py @@ -20,7 +20,7 @@ from torchvision.utils import draw_bounding_boxes from .geo import NonGeoDataset -from .utils import download_url, extract_archive +from .utils import DatasetNotFoundError, download_url, extract_archive class IDTReeS(NonGeoDataset): @@ -166,6 +166,7 @@ def __init__( Raises: ImportError: if laspy is not installed + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in ["train", "test"] assert task in ["task1", "task2"] @@ -443,11 +444,7 @@ def _filter_boxes( return boxes, labels def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" url = self.metadata[self.split]["url"] md5 = self.metadata[self.split]["md5"] filename = self.metadata[self.split]["filename"] @@ -469,11 +466,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - "Dataset not found in `root` directory and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download and extract the dataset download_url( diff --git a/torchgeo/datasets/inaturalist.py b/torchgeo/datasets/inaturalist.py index da59081ca8f..6838a5cdf4b 100644 --- a/torchgeo/datasets/inaturalist.py +++ b/torchgeo/datasets/inaturalist.py @@ -12,7 +12,7 @@ from rasterio.crs import CRS from .geo import GeoDataset -from .utils import BoundingBox, disambiguate_timestamp +from .utils import BoundingBox, DatasetNotFoundError, disambiguate_timestamp class INaturalist(GeoDataset): @@ -40,7 +40,7 @@ def __init__(self, root: str = "data") -> None: root: root directory where dataset can be found Raises: - FileNotFoundError: if no files are found in ``root`` + DatasetNotFoundError: If dataset is not found. """ super().__init__() @@ -48,7 +48,7 @@ def __init__(self, root: str = "data") -> None: files = glob.glob(os.path.join(root, "**.csv")) if not files: - raise FileNotFoundError(f"Dataset not found in `root={self.root}`") + raise DatasetNotFoundError(self) # Read CSV file data = pd.read_csv( diff --git a/torchgeo/datasets/inria.py b/torchgeo/datasets/inria.py index 415e72c0533..b3ab0a6fd9c 100644 --- a/torchgeo/datasets/inria.py +++ b/torchgeo/datasets/inria.py @@ -16,7 +16,12 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, extract_archive, percentile_normalization +from .utils import ( + DatasetNotFoundError, + check_integrity, + extract_archive, + percentile_normalization, +) class InriaAerialImageLabeling(NonGeoDataset): @@ -73,7 +78,7 @@ def __init__( Raises: AssertionError: if ``split`` is invalid - RuntimeError: if dataset is missing + DatasetNotFoundError: If dataset is not found. 
""" self.root = root assert split in {"train", "val", "test"} @@ -185,11 +190,7 @@ def _verify(self) -> None: archive_path = os.path.join(self.root, self.filename) md5_hash = self.md5 if self.checksum else None if not os.path.isfile(archive_path): - raise RuntimeError( - f"Dataset not found in `root={self.root}` " - "either specify a different `root` directory " - "or download the dataset to this directory" - ) + raise DatasetNotFoundError(self) if not check_integrity(archive_path, md5_hash): raise RuntimeError("Dataset corrupted") print("Extracting...") diff --git a/torchgeo/datasets/l7irish.py b/torchgeo/datasets/l7irish.py index 5f1fd760199..04805e85c3a 100644 --- a/torchgeo/datasets/l7irish.py +++ b/torchgeo/datasets/l7irish.py @@ -14,7 +14,7 @@ from torch import Tensor from .geo import RasterDataset -from .utils import BoundingBox, download_url, extract_archive +from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive class L7Irish(RasterDataset): @@ -116,8 +116,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + DatasetNotFoundError: If dataset is not found and *download* is False. """ self.paths = paths self.download = download @@ -130,11 +129,7 @@ def __init__( ) def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the extracted files already exist if self.files: return @@ -148,11 +143,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.paths}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download the dataset self._download() diff --git a/torchgeo/datasets/l8biome.py b/torchgeo/datasets/l8biome.py index 68444df0ce8..e42eaf1b2a7 100644 --- a/torchgeo/datasets/l8biome.py +++ b/torchgeo/datasets/l8biome.py @@ -14,7 +14,7 @@ from torch import Tensor from .geo import RasterDataset -from .utils import BoundingBox, download_url, extract_archive +from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive class L8Biome(RasterDataset): @@ -115,8 +115,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + DatasetNotFoundError: If dataset is not found and *download* is False. """ self.paths = paths self.download = download @@ -129,11 +128,7 @@ def __init__( ) def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the extracted files already exist if self.files: return @@ -147,11 +142,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.paths}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." 
- ) + raise DatasetNotFoundError(self) # Download the dataset self._download() diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index 1a5543dbdd8..8dec50b562e 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -20,7 +20,13 @@ from torch.utils.data import Dataset from .geo import NonGeoDataset, RasterDataset -from .utils import BoundingBox, download_url, extract_archive, working_dir +from .utils import ( + BoundingBox, + DatasetNotFoundError, + download_url, + extract_archive, + working_dir, +) class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC): @@ -84,8 +90,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + DatasetNotFoundError: If dataset is not found and *download* is False. """ self.root = root self.download = download @@ -99,11 +104,7 @@ def __init__( self._verify() def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" if self._verify_data(): return @@ -115,11 +116,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download the dataset self._download() @@ -234,8 +231,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + DatasetNotFoundError: If dataset is not found and *download* is False. """ LandCoverAIBase.__init__(self, root, download, checksum) RasterDataset.__init__(self, root, crs, res, transforms=transforms, cache=cache) @@ -319,8 +315,7 @@ def __init__( Raises: AssertionError: if ``split`` argument is invalid - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in ["train", "val", "test"] diff --git a/torchgeo/datasets/landsat.py b/torchgeo/datasets/landsat.py index 097b3a36a94..c3b7a48f1d3 100644 --- a/torchgeo/datasets/landsat.py +++ b/torchgeo/datasets/landsat.py @@ -79,7 +79,7 @@ def __init__( cache: if True, cache file handle to speed up repeated sampling Raises: - FileNotFoundError: if no files are found in ``paths`` + DatasetNotFoundError: If dataset is not found and *download* is False. .. versionchanged:: 0.5 *root* was renamed to *paths*. diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index e03206f803d..3b5028dcef9 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -15,7 +15,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import download_and_extract_archive +from .utils import DatasetNotFoundError, download_and_extract_archive class LEVIRCDPlus(NonGeoDataset): @@ -72,8 +72,7 @@ def __init__( Raises: AssertionError: if ``split`` argument is invalid - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + DatasetNotFoundError: If dataset is not found and *download* is False. 
""" assert split in self.splits @@ -86,10 +85,7 @@ def __init__( self._download() if not self._check_integrity(): - raise RuntimeError( - "Dataset not found or corrupted. " - + "You can use download=True to download it" - ) + raise DatasetNotFoundError(self) self.files = self._load_files(self.root, self.directory, self.split) @@ -106,9 +102,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image1 = self._load_image(files["image1"]) image2 = self._load_image(files["image2"]) mask = self._load_target(files["mask"]) - - image = torch.stack(tensors=[image1, image2], dim=0) - sample = {"image": image, "mask": mask} + sample = {"image1": image1, "image2": image2, "mask": mask} if self.transforms is not None: sample = self.transforms(sample) @@ -158,7 +152,7 @@ def _load_image(self, path: str) -> Tensor: filename = os.path.join(path) with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) - tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) return tensor @@ -193,11 +187,7 @@ def _check_integrity(self) -> bool: return True def _download(self) -> None: - """Download the dataset and extract it. - - Raises: - AssertionError: if the checksum of split.py does not match - """ + """Download the dataset and extract it.""" if self._check_integrity(): print("Files already downloaded and verified") return @@ -227,20 +217,34 @@ def plot( .. versionadded:: 0.2 """ - image1, image2, mask = (sample["image"][0], sample["image"][1], sample["mask"]) ncols = 3 + def get_rgb(img: Tensor) -> "np.typing.NDArray[np.uint8]": + rgb_img = img.permute(1, 2, 0).float().numpy() + per02 = np.percentile(rgb_img, 2) + per98 = np.percentile(rgb_img, 98) + delta = per98 - per02 + epsilon = 1e-7 + norm_img: "np.typing.NDArray[np.uint8]" = ( + np.clip((rgb_img - per02) / (delta + epsilon), 0, 1) * 255 + ).astype(np.uint8) + return norm_img + + image1 = get_rgb(sample["image1"]) + image2 = get_rgb(sample["image2"]) + mask = sample["mask"].numpy() + if "prediction" in sample: prediction = sample["prediction"] ncols += 1 fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5)) - axs[0].imshow(image1.permute(1, 2, 0)) + axs[0].imshow(image1) axs[0].axis("off") - axs[1].imshow(image2.permute(1, 2, 0)) + axs[1].imshow(image2) axs[1].axis("off") - axs[2].imshow(mask) + axs[2].imshow(mask, cmap="gray") axs[2].axis("off") if "prediction" in sample: diff --git a/torchgeo/datasets/loveda.py b/torchgeo/datasets/loveda.py index 4a817283790..9c7e2aaff4e 100644 --- a/torchgeo/datasets/loveda.py +++ b/torchgeo/datasets/loveda.py @@ -15,7 +15,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import download_and_extract_archive +from .utils import DatasetNotFoundError, download_and_extract_archive class LoveDA(NonGeoDataset): @@ -109,10 +109,8 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - AssertionError: if ``split`` argument is invalid - AssertionError: if ``scene`` argument is invalid - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + AssertionError: if ``split`` or ``scene`` arguments are invalid + DatasetNotFoundError: If dataset is not found and *download* is False. 
""" assert split in self.splits assert set(scene).intersection( @@ -139,10 +137,7 @@ def __init__( self._download() if not self._check_integrity(): - raise RuntimeError( - "Dataset not found at root directory or corrupted. " - + "You can use download=True to download it" - ) + raise DatasetNotFoundError(self) self.files = self._load_files(self.scene_paths, self.split) @@ -249,11 +244,7 @@ def _check_integrity(self) -> bool: return True def _download(self) -> None: - """Download the dataset and extract it. - - Raises: - AssertionError: if the checksum of split.py does not match - """ + """Download the dataset and extract it.""" if self._check_integrity(): print("Files already downloaded and verified") return diff --git a/torchgeo/datasets/mapinwild.py b/torchgeo/datasets/mapinwild.py index 252fef09600..5eaa426d230 100644 --- a/torchgeo/datasets/mapinwild.py +++ b/torchgeo/datasets/mapinwild.py @@ -18,6 +18,7 @@ from .geo import NonGeoDataset from .utils import ( + DatasetNotFoundError, check_integrity, download_url, extract_archive, @@ -128,6 +129,7 @@ def __init__( Raises: AssertionError: if ``split`` argument is invalid + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in ["train", "validation", "test"] @@ -227,9 +229,6 @@ def _verify(self, url: str, md5: Optional[str] = None) -> None: Args: url: url to the file md5: md5 of the file to be verified - - Raises: - RuntimeError: if dataset is not found """ modality_folder_name = url.split("/")[-1] mod_fold_no_ext = modality_folder_name.split(".")[0] @@ -252,11 +251,7 @@ def _verify(self, url: str, md5: Optional[str] = None) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.root}` directory and `download=False`, " # noqa: E501 - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download the dataset self._download(url, md5) diff --git a/torchgeo/datasets/millionaid.py b/torchgeo/datasets/millionaid.py index 48742e0884d..ed9a9c156ea 100644 --- a/torchgeo/datasets/millionaid.py +++ b/torchgeo/datasets/millionaid.py @@ -15,7 +15,7 @@ from torchgeo.datasets import NonGeoDataset -from .utils import check_integrity, extract_archive +from .utils import DatasetNotFoundError, check_integrity, extract_archive class MillionAID(NonGeoDataset): @@ -205,7 +205,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if dataset is not found + DatasetNotFoundError: If dataset is not found. """ self.root = root self.transforms = transforms @@ -326,11 +326,7 @@ def _verify(self) -> None: extract_archive(filepath) return - raise RuntimeError( - f"Dataset not found in `root={self.root}` directory, either " - "specify a different `root` directory or manually download " - "the dataset to this directory." 
- ) + raise DatasetNotFoundError(self) def plot( self, diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py index c685b999255..a1637f46e7d 100644 --- a/torchgeo/datasets/nasa_marine_debris.py +++ b/torchgeo/datasets/nasa_marine_debris.py @@ -15,7 +15,12 @@ from torchvision.utils import draw_bounding_boxes from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import ( + DatasetNotFoundError, + check_integrity, + download_radiant_mlhub_collection, + extract_archive, +) class NASAMarineDebris(NonGeoDataset): @@ -77,6 +82,9 @@ def __init__( api_key: a RadiantEarth MLHub API key to use for downloading the dataset checksum: if True, check the MD5 of the downloaded files (may be slow) verbose: if True, print messages when new tiles are loaded + + Raises: + DatasetNotFoundError: If dataset is not found and *download* is False. """ self.root = root self.transforms = transforms @@ -175,11 +183,7 @@ def _load_files(self) -> list[dict[str, str]]: return files def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the files already exist exists = [ os.path.exists(os.path.join(self.root, directory)) @@ -205,11 +209,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - "Dataset not found in `root` directory and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download and extract the dataset for collection_id in self.collection_ids: diff --git a/torchgeo/datasets/nlcd.py b/torchgeo/datasets/nlcd.py index de2b0c7bd92..6c1243538f3 100644 --- a/torchgeo/datasets/nlcd.py +++ b/torchgeo/datasets/nlcd.py @@ -14,7 +14,7 @@ from rasterio.crs import CRS from .geo import RasterDataset -from .utils import BoundingBox, download_url, extract_archive +from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive class NLCD(RasterDataset): @@ -136,8 +136,7 @@ def __init__( Raises: AssertionError: if ``years`` or ``classes`` are invalid - FileNotFoundError: if no files are found in ``paths`` - RuntimeError: if ``download=False`` but dataset is missing or checksum fails + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert set(years) <= self.md5s.keys(), ( "NLCD data product only exists for the following years: " @@ -182,11 +181,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: return sample def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the extracted files already exist if self.files: return @@ -208,11 +203,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.paths}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." 
- ) + raise DatasetNotFoundError(self) # Download the dataset self._download() diff --git a/torchgeo/datasets/openbuildings.py b/torchgeo/datasets/openbuildings.py index 32c66d73199..e05ea93220c 100644 --- a/torchgeo/datasets/openbuildings.py +++ b/torchgeo/datasets/openbuildings.py @@ -23,7 +23,7 @@ from rtree.index import Index, Property from .geo import VectorDataset -from .utils import BoundingBox, check_integrity +from .utils import BoundingBox, DatasetNotFoundError, check_integrity class OpenBuildings(VectorDataset): @@ -224,7 +224,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - FileNotFoundError: if no files are found in ``root`` + DatasetNotFoundError: If dataset is not found. .. versionchanged:: 0.5 *root* was renamed to *paths*. @@ -284,9 +284,7 @@ def __init__( i += 1 if i == 0: - raise FileNotFoundError( - f"No {self.__class__.__name__} data was found in '{self.paths}'" - ) + raise DatasetNotFoundError(self) self._crs = crs self._source_crs = source_crs @@ -395,12 +393,7 @@ def _wkt_fiona_geom_transform(self, x: str) -> dict[str, Any]: return transformed def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if dataset is missing or checksum fails - FileNotFoundError: if metadata file is not found in root - """ + """Verify the integrity of the dataset.""" # Check if the zip files have already been downloaded and checksum assert isinstance(self.paths, str) pathname = os.path.join(self.paths, self.zipfile_glob) @@ -414,18 +407,7 @@ def _verify(self) -> None: if i != 0: return - # check if the metadata file has been downloaded - if not os.path.exists(os.path.join(self.paths, self.meta_data_filename)): - raise FileNotFoundError( - f"Meta data file {self.meta_data_filename} " - f"not found in in `root={self.paths}`." - ) - - raise RuntimeError( - f"Dataset not found in `root={self.paths}` " - "either specify a different `root` directory or make sure you " - "have manually downloaded the dataset as suggested in the documentation." - ) + raise DatasetNotFoundError(self) def plot( self, diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index e24fbe4612c..e60c4de6214 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -17,6 +17,7 @@ from .geo import NonGeoDataset from .utils import ( + DatasetNotFoundError, download_url, draw_semantic_segmentation_masks, extract_archive, @@ -78,11 +79,29 @@ class OSCD(NonGeoDataset): colormap = ["blue"] + all_bands = ( + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B09", + "B10", + "B11", + "B12", + ) + + rgb_bands = ("B04", "B03", "B02") + def __init__( self, root: str = "data", split: str = "train", - bands: str = "all", + bands: Sequence[str] = all_bands, transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, @@ -99,15 +118,15 @@ def __init__( Raises: AssertionError: if ``split`` argument is invalid - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + DatasetNotFoundError: If dataset is not found and *download* is False. 
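
The switch from `bands: str = "all"` to an explicit band sequence in the OSCD constructor above means a subset keeps its user-given order, while `all_band_indices` maps back into the 13-band files on disk. A quick sketch of that bookkeeping, reusing the `all_bands` and `rgb_bands` tuples defined above (the four-band selection is a hypothetical example):

all_bands = (
    "B01", "B02", "B03", "B04", "B05", "B06", "B07",
    "B08", "B8A", "B09", "B10", "B11", "B12",
)
rgb_bands = ("B04", "B03", "B02")

bands = ("B04", "B03", "B02", "B08")  # hypothetical user selection
all_band_indices = [all_bands.index(b) for b in bands]
print(all_band_indices)  # [3, 2, 1, 7] -> positions in the files on disk

# plot() later locates the RGB channels relative to the subset itself
rgb_indices = [bands.index(b) for b in rgb_bands]
print(rgb_indices)  # [0, 1, 2]
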
""" assert split in self.splits - assert bands in ["rgb", "all"] + assert set(bands) <= set(self.all_bands) + self.bands = bands + self.all_band_indices = [self.all_bands.index(b) for b in self.bands] self.root = root self.split = split - self.bands = bands self.transforms = transforms self.download = download self.checksum = checksum @@ -129,9 +148,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image1 = self._load_image(files["images1"]) image2 = self._load_image(files["images2"]) mask = self._load_target(str(files["mask"])) - - image = torch.cat([image1, image2]) - sample = {"image": image, "mask": mask} + sample = {"image1": image1, "image2": image2, "mask": mask} if self.transforms is not None: sample = self.transforms(sample) @@ -170,8 +187,8 @@ def get_image_paths(ind: int) -> list[str]: ) images1, images2 = get_image_paths(1), get_image_paths(2) - if self.bands == "rgb": - images1, images2 = images1[1:4][::-1], images2[1:4][::-1] + images1 = [images1[i] for i in self.all_band_indices] + images2 = [images2[i] for i in self.all_band_indices] with open(os.path.join(images_root, region, "dates.txt")) as f: dates = tuple( @@ -225,11 +242,7 @@ def _load_target(self, path: str) -> Tensor: return tensor def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the extracted files already exist pathname = os.path.join(self.root, "**", self.filename_glob) for fname in glob.iglob(pathname, recursive=True): @@ -244,11 +257,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download the dataset self._download() @@ -287,13 +296,21 @@ def plot( Returns: a matplotlib Figure with the rendered sample + + Raises: + ValueError: If *bands* does not include all RGB bands. """ ncols = 2 - rgb_inds = [3, 2, 1] if self.bands == "all" else [0, 1, 2] + try: + rgb_indices = [self.bands.index(band) for band in self.rgb_bands] + except ValueError as e: + raise ValueError( + "RGB bands must be present to use `plot` with S2 imagery." 
+ ) from e def get_masked(img: Tensor) -> "np.typing.NDArray[np.uint8]": - rgb_img = img[rgb_inds].float().numpy() + rgb_img = img[rgb_indices].float().numpy() per02 = np.percentile(rgb_img, 2) per98 = np.percentile(rgb_img, 98) rgb_img = (np.clip((rgb_img - per02) / (per98 - per02), 0, 1) * 255).astype( @@ -307,9 +324,8 @@ def get_masked(img: Tensor) -> "np.typing.NDArray[np.uint8]": ) return array - idx = sample["image"].shape[0] // 2 - image1 = get_masked(sample["image"][:idx]) - image2 = get_masked(sample["image"][idx:]) + image1 = get_masked(sample["image1"]) + image2 = get_masked(sample["image2"]) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) axs[0].imshow(image1) axs[0].axis("off") diff --git a/torchgeo/datasets/pastis.py b/torchgeo/datasets/pastis.py index d8091443550..f551408b039 100644 --- a/torchgeo/datasets/pastis.py +++ b/torchgeo/datasets/pastis.py @@ -16,7 +16,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, download_url, extract_archive +from .utils import DatasetNotFoundError, check_integrity, download_url, extract_archive class PASTIS(NonGeoDataset): @@ -149,6 +149,9 @@ def __init__( entry and returns a transformed version download: if True, download dataset and store it in the root directory checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert set(folds) <= set(range(6)) assert bands in ["s1a", "s1d", "s2"] @@ -308,11 +311,7 @@ def _load_files(self) -> list[dict[str, str]]: return files def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the directory already exists path = os.path.join(self.root, self.directory) if os.path.exists(path): @@ -328,11 +327,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download and extract the dataset self._download() diff --git a/torchgeo/datasets/patternnet.py b/torchgeo/datasets/patternnet.py index 51329e779f8..876f14dd59b 100644 --- a/torchgeo/datasets/patternnet.py +++ b/torchgeo/datasets/patternnet.py @@ -11,7 +11,7 @@ from torch import Tensor from .geo import NonGeoClassificationDataset -from .utils import download_url, extract_archive +from .utils import DatasetNotFoundError, download_url, extract_archive class PatternNet(NonGeoClassificationDataset): @@ -96,6 +96,9 @@ def __init__( entry and returns a transformed version download: if True, download dataset and store it in the root directory checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + DatasetNotFoundError: If dataset is not found and *download* is False. """ self.root = root self.download = download @@ -104,11 +107,7 @@ def __init__( super().__init__(root=os.path.join(root, self.directory), transforms=transforms) def _verify(self) -> None: - """Verify the integrity of the dataset. 
- - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the files already exist filepath = os.path.join(self.root, self.directory) if os.path.exists(filepath): @@ -122,11 +121,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - "Dataset not found in `root` directory and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download and extract the dataset self._download() diff --git a/torchgeo/datasets/potsdam.py b/torchgeo/datasets/potsdam.py index 5443c807d22..782fd7ce87d 100644 --- a/torchgeo/datasets/potsdam.py +++ b/torchgeo/datasets/potsdam.py @@ -16,6 +16,7 @@ from .geo import NonGeoDataset from .utils import ( + DatasetNotFoundError, check_integrity, draw_semantic_segmentation_masks, extract_archive, @@ -133,6 +134,10 @@ def __init__( transforms: a function/transform that takes input sample and its target as entry and returns a transformed version checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + AssertionError: If *split* is invalid. + DatasetNotFoundError: If dataset is not found. """ assert split in self.splits self.root = root @@ -209,11 +214,7 @@ def _load_target(self, index: int) -> Tensor: return tensor def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if checksum fails or the dataset is not downloaded - """ + """Verify the integrity of the dataset.""" # Check if the files already exist if os.path.exists(os.path.join(self.root, self.image_root)): return @@ -233,11 +234,7 @@ def _verify(self) -> None: if all(exists): return - # Check if the user requested to download the dataset - raise RuntimeError( - "Dataset not found in `root` directory, either specify a different" - + " `root` directory or manually download the dataset to this directory." - ) + raise DatasetNotFoundError(self) def plot( self, diff --git a/torchgeo/datasets/reforestree.py b/torchgeo/datasets/reforestree.py index a7bfa54329f..9ee8a82d617 100644 --- a/torchgeo/datasets/reforestree.py +++ b/torchgeo/datasets/reforestree.py @@ -17,7 +17,12 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, download_and_extract_archive, extract_archive +from .utils import ( + DatasetNotFoundError, + check_integrity, + download_and_extract_archive, + extract_archive, +) class ReforesTree(NonGeoDataset): @@ -78,8 +83,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + DatasetNotFoundError: If dataset is not found and *download* is False. """ self.root = root self.transforms = transforms @@ -173,11 +177,7 @@ def _load_target(self, filepath: str) -> tuple[Tensor, ...]: return boxes, labels, agb def _verify(self) -> None: - """Checks the integrity of the dataset structure. 
- - Raises: - RuntimeError: if dataset is not found in root or is corrupted - """ + """Checks the integrity of the dataset structure.""" filepaths = [os.path.join(self.root, dir) for dir in ["tiles", "mapping"]] if all([os.path.exists(filepath) for filepath in filepaths]): return @@ -191,21 +191,13 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # else download the dataset self._download() def _download(self) -> None: - """Download the dataset and extract it. - - Raises: - AssertionError: if the checksum does not match - """ + """Download the dataset and extract it.""" download_and_extract_archive( self.url, self.root, diff --git a/torchgeo/datasets/resisc45.py b/torchgeo/datasets/resisc45.py index 35f02a7c665..07fdfd272a8 100644 --- a/torchgeo/datasets/resisc45.py +++ b/torchgeo/datasets/resisc45.py @@ -12,7 +12,7 @@ from torch import Tensor from .geo import NonGeoClassificationDataset -from .utils import download_url, extract_archive +from .utils import DatasetNotFoundError, download_url, extract_archive class RESISC45(NonGeoClassificationDataset): @@ -89,7 +89,6 @@ class RESISC45(NonGeoClassificationDataset): If you use this dataset in your research, please cite the following paper: * https://doi.org/10.1109/jproc.2017.2675998 - """ url = "https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv" @@ -108,53 +107,6 @@ class RESISC45(NonGeoClassificationDataset): "val": "a0770cee4c5ca20b8c32bbd61e114805", "test": "3dda9e4988b47eb1de9f07993653eb08", } - classes = [ - "airplane", - "airport", - "baseball_diamond", - "basketball_court", - "beach", - "bridge", - "chaparral", - "church", - "circular_farmland", - "cloud", - "commercial_area", - "dense_residential", - "desert", - "forest", - "freeway", - "golf_course", - "ground_track_field", - "harbor", - "industrial_area", - "intersection", - "island", - "lake", - "meadow", - "medium_residential", - "mobile_home_park", - "mountain", - "overpass", - "palace", - "parking_lot", - "railway", - "railway_station", - "rectangular_farmland", - "river", - "roundabout", - "runway", - "sea_ice", - "ship", - "snowberg", - "sparse_residential", - "stadium", - "storage_tank", - "tennis_court", - "terrace", - "thermal_power_station", - "wetland", - ] def __init__( self, @@ -173,6 +125,9 @@ def __init__( entry and returns a transformed version download: if True, download dataset and store it in the root directory checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in self.splits self.root = root @@ -193,11 +148,7 @@ def __init__( ) def _verify(self) -> None: - """Verify the integrity of the dataset. 
- - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the files already exist filepath = os.path.join(self.root, self.directory) if os.path.exists(filepath): @@ -211,11 +162,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - "Dataset not found in `root` directory and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download and extract the dataset self._download() diff --git a/torchgeo/datasets/rwanda_field_boundary.py b/torchgeo/datasets/rwanda_field_boundary.py index 8d40960f8da..32f6a1625e2 100644 --- a/torchgeo/datasets/rwanda_field_boundary.py +++ b/torchgeo/datasets/rwanda_field_boundary.py @@ -16,7 +16,12 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import ( + DatasetNotFoundError, + check_integrity, + download_radiant_mlhub_collection, + extract_archive, +) class RwandaFieldBoundary(NonGeoDataset): @@ -103,8 +108,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - or if ``download=True`` and ``api_key=None`` + DatasetNotFoundError: If dataset is not found and *download* is False. """ self._validate_bands(bands) assert split in self.splits @@ -200,11 +204,7 @@ def _validate_bands(self, bands: Sequence[str]) -> None: raise ValueError(f"'{band}' is an invalid band name.") def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the subdirectories already exist and have the correct number of files checks = [] for split, num_patches in self.number_of_patches_per_split.items(): @@ -236,21 +236,13 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download and extract the dataset self._download() def _download(self) -> None: - """Download the dataset and extract it. 
- - Raises: - RuntimeError: if download doesn't work correctly or checksums don't match - """ + """Download the dataset and extract it.""" for collection_id in self.collection_ids: download_radiant_mlhub_collection(collection_id, self.root, self.api_key) diff --git a/torchgeo/datasets/seasonet.py b/torchgeo/datasets/seasonet.py index a13b4897c32..6d61fa011f1 100644 --- a/torchgeo/datasets/seasonet.py +++ b/torchgeo/datasets/seasonet.py @@ -20,7 +20,12 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import download_url, extract_archive, percentile_normalization +from .utils import ( + DatasetNotFoundError, + download_url, + extract_archive, + percentile_normalization, +) class SeasoNet(NonGeoDataset): @@ -233,6 +238,9 @@ def __init__( entry and returns a transformed version download: if True, download dataset and store it in the root directory checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in self.splits assert set(seasons) <= self.all_seasons @@ -354,11 +362,7 @@ def _load_target(self, index: int) -> Tensor: return tensor def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if all files already exist if all( os.path.exists(os.path.join(self.root, file_info["name"])) @@ -378,12 +382,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if missing and not self.download: - raise RuntimeError( - f"{', '.join([m['name'] for m in missing])} not found in" - " `root={self.root}` and `download=False`, either specify a" - " different `root` directory or use `download=True`" - " to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download missing files for file_info in missing: diff --git a/torchgeo/datasets/seco.py b/torchgeo/datasets/seco.py index 29a9d32250c..e6abd0e1d2b 100644 --- a/torchgeo/datasets/seco.py +++ b/torchgeo/datasets/seco.py @@ -16,7 +16,12 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import download_url, extract_archive, percentile_normalization +from .utils import ( + DatasetNotFoundError, + download_url, + extract_archive, + percentile_normalization, +) class SeasonalContrastS2(NonGeoDataset): @@ -94,8 +99,7 @@ def __init__( Raises: AssertionError: if ``version`` argument is invalid - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert version in self.metadata.keys() assert seasons in range(5) @@ -183,11 +187,7 @@ def _load_patch(self, root: str, subdir: str) -> Tensor: return image def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the extracted files already exist directory_path = os.path.join( self.root, self.metadata[self.version]["directory"] @@ -203,11 +203,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." 
- ) + raise DatasetNotFoundError(self) # Download the dataset self._download() diff --git a/torchgeo/datasets/sen12ms.py b/torchgeo/datasets/sen12ms.py index 9370fbf1d7a..1a0f812c368 100644 --- a/torchgeo/datasets/sen12ms.py +++ b/torchgeo/datasets/sen12ms.py @@ -15,7 +15,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, percentile_normalization +from .utils import DatasetNotFoundError, check_integrity, percentile_normalization class SEN12MS(NonGeoDataset): @@ -189,7 +189,7 @@ def __init__( Raises: AssertionError: if ``split`` argument is invalid - RuntimeError: if data is not found in ``root``, or checksums don't match + DatasetNotFoundError: If dataset is not found. """ assert split in ["train", "test"] @@ -204,12 +204,10 @@ def __init__( self.transforms = transforms self.checksum = checksum - if checksum: - if not self._check_integrity(): - raise RuntimeError("Dataset not found or corrupted.") - else: - if not self._check_integrity_light(): - raise RuntimeError("Dataset not found or corrupted.") + if ( + checksum and not self._check_integrity() + ) or not self._check_integrity_light(): + raise DatasetNotFoundError(self) with open(os.path.join(self.root, split + "_list.txt")) as f: self.ids = [line.rstrip() for line in f.readlines()] diff --git a/torchgeo/datasets/sentinel.py b/torchgeo/datasets/sentinel.py index 1c4e2423482..2c1eebe51a7 100644 --- a/torchgeo/datasets/sentinel.py +++ b/torchgeo/datasets/sentinel.py @@ -162,7 +162,7 @@ def __init__( Raises: AssertionError: if ``bands`` is invalid - FileNotFoundError: if no files are found in ``paths`` + DatasetNotFoundError: If dataset is not found. .. versionchanged:: 0.5 *root* was renamed to *paths*. @@ -317,7 +317,7 @@ def __init__( cache: if True, cache file handle to speed up repeated sampling Raises: - FileNotFoundError: if no files are found in ``paths`` + DatasetNotFoundError: If dataset is not found. .. versionchanged:: 0.5 *root* was renamed to *paths* diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py index 1242678c8b9..156b3f1568d 100644 --- a/torchgeo/datasets/skippd.py +++ b/torchgeo/datasets/skippd.py @@ -14,7 +14,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import download_url, extract_archive +from .utils import DatasetNotFoundError, download_url, extract_archive class SKIPPD(NonGeoDataset): @@ -91,8 +91,8 @@ def __init__( Raises: AssertionError: if ``task`` or ``split`` is invalid + DatasetNotFoundError: If dataset is not found and *download* is False. ImportError: if h5py is not installed - RuntimeError: if ``download=False`` but dataset is missing or checksum fails """ assert ( split in self.valid_splits @@ -202,11 +202,7 @@ def _load_features(self, index: int) -> dict[str, Union[str, Tensor]]: return features def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the extracted files already exist pathname = os.path.join(self.root, self.data_file_name.format(self.task)) if os.path.exists(pathname): @@ -220,22 +216,14 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." 
- ) + raise DatasetNotFoundError(self) # Download the dataset self._download() self._extract() def _download(self) -> None: - """Download the dataset and extract it. - - Raises: - RuntimeError: if download doesn't work correctly or checksums don't match - """ + """Download the dataset and extract it.""" download_url( self.url.format(self.zipfile_name.format(self.task)), self.root, diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py index d5623a6ab1c..dff7276c65a 100644 --- a/torchgeo/datasets/so2sat.py +++ b/torchgeo/datasets/so2sat.py @@ -14,7 +14,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, percentile_normalization +from .utils import DatasetNotFoundError, check_integrity, percentile_normalization class So2Sat(NonGeoDataset): @@ -208,7 +208,7 @@ def __init__( Raises: AssertionError: if ``split`` argument is invalid - RuntimeError: if data is not found in ``root``, or checksums don't match + DatasetNotFoundError: If dataset is not found. .. versionadded:: 0.3 The *bands* parameter. @@ -257,7 +257,7 @@ def __init__( self.fn = os.path.join(self.root, self.filenames_by_version[version][split]) if not self._check_integrity(): - raise RuntimeError("Dataset not found or corrupted.") + raise DatasetNotFoundError(self) with h5py.File(self.fn, "r") as f: self.size: int = f["label"].shape[0] diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index a4d1edb1ddd..19f10311261 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -35,7 +35,7 @@ class south_america_soybean(RasterDataset): date_format = "%Y" is_image = False - url = "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_" + url = "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2001.tif" md5s = { 2001: "2914b0af7590a0ca4dfa9ccefc99020f", diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py index e1f00347dfe..c6780e1971c 100644 --- a/torchgeo/datasets/spacenet.py +++ b/torchgeo/datasets/spacenet.py @@ -26,6 +26,7 @@ from .geo import NonGeoDataset from .utils import ( + DatasetNotFoundError, check_integrity, download_radiant_mlhub_collection, download_radiant_mlhub_dataset, @@ -98,7 +99,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if ``download=False`` but dataset is missing + DatasetNotFoundError: If dataset is not found and *download* is False. """ self.root = root self.image = image # For testing @@ -116,11 +117,7 @@ def __init__( if to_be_downloaded: if not download: - raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " - "either specify a different `root` directory or use " - "`download=True` to automatically download the dataset." 
- ) + raise DatasetNotFoundError(self) else: self._download(to_be_downloaded, api_key) @@ -283,9 +280,6 @@ def _download(self, collections: list[str], api_key: Optional[str] = None) -> No Args: collections: Collections to be downloaded api_key: a RadiantEarth MLHub API key to use for downloading the dataset - - Raises: - RuntimeError: if download doesn't work correctly or checksums don't match """ for collection in collections: download_radiant_mlhub_collection(collection, self.root, api_key) @@ -421,7 +415,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if ``download=False`` but dataset is missing + DatasetNotFoundError: If dataset is not found and *download* is False. """ collections = ["sn1_AOI_1_RIO"] assert image in {"rgb", "8band"} @@ -541,7 +535,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if ``download=False`` but dataset is missing + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert image in {"MS", "PAN", "PS-MS", "PS-RGB"} super().__init__( @@ -664,7 +658,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if ``download=False`` but dataset is missing + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert image in {"MS", "PAN", "PS-MS", "PS-RGB"} self.speed_mask = speed_mask @@ -909,7 +903,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if ``download=False`` but dataset is missing + DatasetNotFoundError: If dataset is not found and *download* is False. """ collections = ["sn4_AOI_6_Atlanta"] assert image in {"MS", "PAN", "PS-RGBNIR"} @@ -1081,7 +1075,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if ``download=False`` but dataset is missing + DatasetNotFoundError: If dataset is not found and *download* is False. """ super().__init__( root, @@ -1205,7 +1199,7 @@ def __init__( api_key: a RadiantEarth MLHub API key to use for downloading the dataset Raises: - RuntimeError: if ``download=False`` but dataset is missing + DatasetNotFoundError: If dataset is not found and *download* is False. """ self.root = root self.image = image # For testing @@ -1223,9 +1217,6 @@ def __download(self, api_key: Optional[str] = None) -> None: Args: api_key: a RadiantEarth MLHub API key to use for downloading the dataset - - Raises: - RuntimeError: if download doesn't work correctly or checksums don't match """ if os.path.exists( os.path.join( @@ -1307,7 +1298,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if ``download=False`` but dataset is missing + DatasetNotFoundError: If dataset is not found and *download* is False. """ self.root = root self.split = split @@ -1326,11 +1317,7 @@ def __init__( if to_be_downloaded: if not download: - raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " - "either specify a different `root` directory or use " - "`download=True` to automatically download the dataset." 
- ) + raise DatasetNotFoundError(self) else: self._download(to_be_downloaded, api_key) diff --git a/torchgeo/datasets/ssl4eo.py b/torchgeo/datasets/ssl4eo.py index 27fd5c11d41..fa486c4aa79 100644 --- a/torchgeo/datasets/ssl4eo.py +++ b/torchgeo/datasets/ssl4eo.py @@ -16,7 +16,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, download_url, extract_archive +from .utils import DatasetNotFoundError, check_integrity, download_url, extract_archive class SSL4EO(NonGeoDataset): @@ -180,7 +180,7 @@ def __init__( Raises: AssertionError: if any arguments are invalid - RuntimeError: if ``download=False`` but dataset is missing or checksum fails + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in self.metadata assert seasons in range(1, 5) @@ -234,11 +234,7 @@ def __len__(self) -> int: return len(self.scenes) def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the extracted files already exist path = os.path.join(self.subdir, "00000*", "*", "all_bands.tif") if glob.glob(path): @@ -256,11 +252,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download the dataset self._download() @@ -430,7 +422,7 @@ def __init__( Raises: AssertionError: if ``split`` argument is invalid - RuntimeError: if dataset is missing or checksum fails + DatasetNotFoundError: If dataset is not found. """ assert split in self.metadata assert seasons in range(1, 5) @@ -483,11 +475,7 @@ def __len__(self) -> int: return 251079 def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the extracted files already exist directory_path = os.path.join(self.root, self.split) if os.path.exists(directory_path): @@ -501,7 +489,7 @@ def _verify(self) -> None: if integrity: self._extract() else: - raise RuntimeError(f"Dataset not found in `root={self.root}`") + raise DatasetNotFoundError(self) def _extract(self) -> None: """Extract the dataset.""" diff --git a/torchgeo/datasets/ssl4eo_benchmark.py b/torchgeo/datasets/ssl4eo_benchmark.py index 904d7236aad..a2cd92867b7 100644 --- a/torchgeo/datasets/ssl4eo_benchmark.py +++ b/torchgeo/datasets/ssl4eo_benchmark.py @@ -17,7 +17,7 @@ from .cdl import CDL from .geo import NonGeoDataset from .nlcd import NLCD -from .utils import download_url, extract_archive +from .utils import DatasetNotFoundError, download_url, extract_archive class SSL4EOLBenchmark(NonGeoDataset): @@ -131,7 +131,7 @@ def __init__( Raises: AssertionError: if any arguments are invalid - RuntimeError: if ``download=False`` but dataset is missing or checksum fails + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert ( sensor in self.valid_sensors @@ -190,11 +190,7 @@ def __init__( self.ordinal_cmap[v] = torch.tensor(self.cmap[k]) def _verify(self) -> None: - """Verify the integrity of the dataset. 
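
The `ordinal_map`/`ordinal_cmap` bookkeeping in `SSL4EOLBenchmark.__init__` above compresses sparse raw class IDs into consecutive ordinals so masks can serve directly as training targets. A self-contained sketch with illustrative class codes (real datasets define their own `cmap` tables and dtypes):

import torch

# Illustrative raw class IDs -> RGBA colors; values here are placeholders
cmap = {0: (0, 0, 0, 0), 11: (70, 107, 159, 255), 21: (222, 197, 197, 255)}
classes = [0, 11, 21]  # user-selected classes; 0 stays as background

ordinal_map = torch.zeros(max(cmap) + 1, dtype=torch.long)  # raw ID -> ordinal
ordinal_cmap = torch.zeros((len(classes), 4), dtype=torch.uint8)
for v, k in enumerate(classes):
    ordinal_map[k] = v
    ordinal_cmap[v] = torch.tensor(cmap[k])

mask = torch.tensor([[0, 11], [21, 11]])
print(ordinal_map[mask])  # tensor([[0, 1], [2, 1]])
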
- - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the extracted files already exist img_pathname = os.path.join(self.root, self.img_dir_name, "**", "all_bands.tif") exists = [] @@ -223,11 +219,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download the dataset self._download() diff --git a/torchgeo/datasets/sustainbench_crop_yield.py b/torchgeo/datasets/sustainbench_crop_yield.py index 5619697b932..5dca4e6d969 100644 --- a/torchgeo/datasets/sustainbench_crop_yield.py +++ b/torchgeo/datasets/sustainbench_crop_yield.py @@ -13,7 +13,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import download_url, extract_archive +from .utils import DatasetNotFoundError, download_url, extract_archive class SustainBenchCropYield(NonGeoDataset): @@ -78,7 +78,7 @@ def __init__( Raises: AssertionError: if ``countries`` contains invalid countries or if ``split`` is invalid - RuntimeError: if ``download=False`` but dataset is missing or checksum fails + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert set(countries).issubset( self.valid_countries @@ -186,11 +186,7 @@ def retrieve_collection(self) -> list[tuple[str, int]]: return collection def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the extracted files already exist pathname = os.path.join(self.root, self.dir) if os.path.exists(pathname): @@ -204,22 +200,14 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download the dataset self._download() self._extract() def _download(self) -> None: - """Download the dataset and extract it. 
- - Raises: - RuntimeError: if download doesn't work correctly or checksums don't match - """ + """Download the dataset and extract it.""" download_url( self.url, self.root, diff --git a/torchgeo/datasets/ucmerced.py b/torchgeo/datasets/ucmerced.py index 0fc9551f0da..4a7867dc4a5 100644 --- a/torchgeo/datasets/ucmerced.py +++ b/torchgeo/datasets/ucmerced.py @@ -12,7 +12,7 @@ from torch import Tensor from .geo import NonGeoClassificationDataset -from .utils import check_integrity, download_url, extract_archive +from .utils import DatasetNotFoundError, check_integrity, download_url, extract_archive class UCMerced(NonGeoClassificationDataset): @@ -68,29 +68,6 @@ class UCMerced(NonGeoClassificationDataset): md5 = "5b7ec56793786b6dc8a908e8854ac0e4" base_dir = os.path.join("UCMerced_LandUse", "Images") - classes = [ - "agricultural", - "airplane", - "baseballdiamond", - "beach", - "buildings", - "chaparral", - "denseresidential", - "forest", - "freeway", - "golfcourse", - "harbor", - "intersection", - "mediumresidential", - "mobilehomepark", - "overpass", - "parkinglot", - "river", - "runway", - "sparseresidential", - "storagetanks", - "tenniscourt", - ] splits = ["train", "val", "test"] split_urls = { @@ -123,8 +100,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in self.splits self.root = root @@ -170,11 +146,7 @@ def _check_integrity(self) -> bool: return integrity def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the files already exist filepath = os.path.join(self.root, self.base_dir) if os.path.exists(filepath): @@ -187,11 +159,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - "Dataset not found in `root` directory and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download and extract the dataset self._download() @@ -237,6 +205,11 @@ def plot( .. versionadded:: 0.2 """ image = np.rollaxis(sample["image"].numpy(), 0, 3) + + # Normalize the image if the max value is greater than 1 + if image.max() > 1: + image = image.astype(np.float32) / 255.0 # Scale to [0, 1] + label = cast(int, sample["label"].item()) label_class = self.classes[label] diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 2b9f65b4d93..f33d268d4ce 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -17,7 +17,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import download_url, extract_archive +from .utils import DatasetNotFoundError, download_url, extract_archive class USAVars(NonGeoDataset): @@ -106,8 +106,7 @@ def __init__( Raises: AssertionError: if invalid labels are provided - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + DatasetNotFoundError: If dataset is not found and *download* is False. """ self.root = root @@ -186,11 +185,7 @@ def _load_image(self, path: str) -> Tensor: return tensor def _verify(self) -> None: - """Verify the integrity of the dataset. 
- - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the extracted files already exist pathname = os.path.join(self.root, "uar") csv_pathname = os.path.join(self.root, "*.csv") @@ -208,11 +203,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) self._download() self._extract() diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 55d94abed6f..6446e7e64e1 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -23,11 +23,13 @@ import rasterio import torch from torch import Tensor +from torch.utils.data import Dataset from torchvision.datasets.utils import check_integrity, download_url from torchvision.utils import draw_segmentation_masks __all__ = ( "check_integrity", + "DatasetNotFoundError", "download_url", "download_and_extract_archive", "extract_archive", @@ -46,6 +48,49 @@ ) +class DatasetNotFoundError(FileNotFoundError): + """Raised when a dataset is requested but doesn't exist. + + .. versionadded:: 0.6 + """ + + def __init__(self, dataset: Dataset[object]) -> None: + """Intstantiate a new DatasetNotFoundError instance. + + Args: + dataset: The dataset that was requested. + """ + msg = "Dataset not found" + + if hasattr(dataset, "root"): + var = "root" + val = dataset.root + elif hasattr(dataset, "paths"): + var = "paths" + val = dataset.paths + else: + super().__init__(f"{msg}.") + return + + msg += f" in `{var}={val!r}` and " + + if hasattr(dataset, "download") and not dataset.download: + msg += "`download=False`" + else: + msg += "cannot be automatically downloaded" + + msg += f", either specify a different `{var}` or " + + if hasattr(dataset, "download") and not dataset.download: + msg += "use `download=True` to automatically" + else: + msg += "manually" + + msg += " download the dataset." + + super().__init__(msg) + + class _rarfile: class RarFile: def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -737,3 +782,27 @@ def percentile_normalization( (img - lower_percentile) / (upper_percentile - lower_percentile + 1e-5), 0, 1 ) return img_normalized + + +def path_is_vsi(path: str) -> bool: + """Checks if the given path is pointing to a Virtual File System. + + .. note:: + Does not check if the path exists, or if it is a dir or file. + + VSI can for instance be Cloud Storage Blobs or zip-archives. + They will start with a prefix indicating this. + For examples of these, see references for the two accepted syntaxes. + + * https://gdal.org/user/virtual_file_systems.html + * https://rasterio.readthedocs.io/en/latest/topics/datasets.html + + Args: + path: string representing a directory or file + + Returns: + True if path is on a virtual file system, else False + + .. 
versionadded:: 0.6 + """ + return "://" in path or path.startswith("/vsi") diff --git a/torchgeo/datasets/vaihingen.py b/torchgeo/datasets/vaihingen.py index 78370f31585..59ecfcda690 100644 --- a/torchgeo/datasets/vaihingen.py +++ b/torchgeo/datasets/vaihingen.py @@ -15,6 +15,7 @@ from .geo import NonGeoDataset from .utils import ( + DatasetNotFoundError, check_integrity, draw_semantic_segmentation_masks, extract_archive, @@ -132,6 +133,10 @@ def __init__( transforms: a function/transform that takes input sample and its target as entry and returns a transformed version checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + AssertionError: If *split* is invalid. + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in self.splits self.root = root @@ -210,11 +215,7 @@ def _load_target(self, index: int) -> Tensor: return tensor def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if checksum fails or the dataset is not downloaded - """ + """Verify the integrity of the dataset.""" # Check if the files already exist if os.path.exists(os.path.join(self.root, self.image_root)): return @@ -234,11 +235,7 @@ def _verify(self) -> None: if all(exists): return - # Check if the user requested to download the dataset - raise RuntimeError( - "Dataset not found in `root` directory, either specify a different" - + " `root` directory or manually download the dataset to this directory." - ) + raise DatasetNotFoundError(self) def plot( self, diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index 83add564be1..db0807ee930 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -15,7 +15,12 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, download_and_extract_archive, download_url +from .utils import ( + DatasetNotFoundError, + check_integrity, + download_and_extract_archive, + download_url, +) def convert_coco_poly_to_mask( @@ -200,8 +205,7 @@ def __init__( Raises: AssertionError: if ``split`` argument is invalid ImportError: if ``split="positive"`` and pycocotools is not installed - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in ["positive", "negative"] @@ -214,10 +218,7 @@ def __init__( self._download() if not self._check_integrity(): - raise RuntimeError( - "Dataset not found or corrupted. " - + "You can use download=True to download it" - ) + raise DatasetNotFoundError(self) if split == "positive": # Must be installed to parse annotations file diff --git a/torchgeo/datasets/western_usa_live_fuel_moisture.py b/torchgeo/datasets/western_usa_live_fuel_moisture.py index bf383593112..d782cd2b50a 100644 --- a/torchgeo/datasets/western_usa_live_fuel_moisture.py +++ b/torchgeo/datasets/western_usa_live_fuel_moisture.py @@ -13,7 +13,11 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import download_radiant_mlhub_collection, extract_archive +from .utils import ( + DatasetNotFoundError, + download_radiant_mlhub_collection, + extract_archive, +) class WesternUSALiveFuelMoisture(NonGeoDataset): @@ -218,7 +222,7 @@ def __init__( Raises: AssertionError: if ``input_features`` contains invalid variable names - RuntimeError: if ``download=False`` but dataset is missing or checksum fails + DatasetNotFoundError: If dataset is not found and *download* is False. 
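
Taken together, the two `utils.py` additions in the diff above work as follows: `DatasetNotFoundError` inspects the failing dataset for `root`/`paths` and `download` attributes to build its message, and `path_is_vsi` is a pure string check. A sketch, assuming this patch is applied so both names are importable from `torchgeo.datasets.utils` (the `Dummy` class is only a stand-in with the attributes the error looks for):

from torchgeo.datasets.utils import DatasetNotFoundError, path_is_vsi


class Dummy:
    """Stand-in exposing the attributes DatasetNotFoundError inspects."""

    root = "data"
    download = False


try:
    raise DatasetNotFoundError(Dummy())
except DatasetNotFoundError as err:
    print(err)
    # Dataset not found in `root='data'` and `download=False`, either specify a
    # different `root` or use `download=True` to automatically download the dataset.

print(path_is_vsi("/vsizip/archive.zip/scene.tif"))  # True: GDAL VSI prefix
print(path_is_vsi("s3://bucket/scene.tif"))  # True: URL-style scheme
print(path_is_vsi("data/scene.tif"))  # False: plain local path
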
""" super().__init__() @@ -300,11 +304,7 @@ def _load_data(self) -> pd.DataFrame: return df def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the extracted files already exist pathname = os.path.join(self.root, self.collection_id) if os.path.exists(pathname): @@ -318,11 +318,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download the dataset self._download() @@ -338,9 +334,6 @@ def _download(self, api_key: Optional[str] = None) -> None: Args: api_key: a RadiantEarth MLHub API key to use for downloading the dataset - - Raises: - RuntimeError: if download doesn't work correctly or checksums don't match """ download_radiant_mlhub_collection(self.collection_id, self.root, api_key) filename = os.path.join(self.root, self.collection_id) + ".tar.gz" diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 665d156db04..55eaa6735c8 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -15,7 +15,12 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, draw_semantic_segmentation_masks, extract_archive +from .utils import ( + DatasetNotFoundError, + check_integrity, + draw_semantic_segmentation_masks, + extract_archive, +) class XView2(NonGeoDataset): @@ -78,6 +83,10 @@ def __init__( transforms: a function/transform that takes input sample and its target as entry and returns a transformed version checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + AssertionError: If *split* is invalid. + DatasetNotFoundError: If dataset is not found. """ assert split in self.metadata self.root = root @@ -181,11 +190,7 @@ def _load_target(self, path: str) -> Tensor: return tensor def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if checksum fails or the dataset is not downloaded - """ + """Verify the integrity of the dataset.""" # Check if the files already exist exists = [] for split_info in self.metadata.values(): @@ -214,11 +219,7 @@ def _verify(self) -> None: if all(exists): return - # Check if the user requested to download the dataset - raise RuntimeError( - "Dataset not found in `root` directory, either specify a different" - + " `root` directory or manually download the dataset to this directory." - ) + raise DatasetNotFoundError(self) def plot( self, diff --git a/torchgeo/datasets/zuericrop.py b/torchgeo/datasets/zuericrop.py index a2e3e22ab60..2047b121adc 100644 --- a/torchgeo/datasets/zuericrop.py +++ b/torchgeo/datasets/zuericrop.py @@ -13,7 +13,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import download_url, percentile_normalization +from .utils import DatasetNotFoundError, download_url, percentile_normalization class ZueriCrop(NonGeoDataset): @@ -81,8 +81,7 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + DatasetNotFoundError: If dataset is not found and *download* is False. 
""" self._validate_bands(bands) self.band_indices = torch.tensor( @@ -209,11 +208,7 @@ def _load_target(self, index: int) -> tuple[Tensor, Tensor, Tensor]: return masks, boxes, labels def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the files already exist exists = [] for filename in self.filenames: @@ -225,11 +220,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - "Dataset not found in `root` directory and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." - ) + raise DatasetNotFoundError(self) # Download the dataset self._download() @@ -251,6 +242,7 @@ def _validate_bands(self, bands: Sequence[str]) -> None: Args: bands: user-provided sequence of bands to load + Raises: AssertionError: if ``bands`` is not a sequence ValueError: if an invalid band name is provided diff --git a/torchgeo/trainers/base.py b/torchgeo/trainers/base.py index c8371c6ffa4..a549cbe3f7b 100644 --- a/torchgeo/trainers/base.py +++ b/torchgeo/trainers/base.py @@ -8,7 +8,6 @@ import lightning from lightning.pytorch import LightningModule -from lightning.pytorch.callbacks import Callback, EarlyStopping, ModelCheckpoint from torch.optim import AdamW from torch.optim.lr_scheduler import ReduceLROnPlateau @@ -36,17 +35,6 @@ def __init__(self) -> None: self.configure_metrics() self.configure_models() - def configure_callbacks(self) -> list[Callback]: - """Initialize model-specific callbacks. - - Returns: - List of callbacks to apply. - """ - return [ - ModelCheckpoint(monitor=self.monitor, mode=self.mode), - EarlyStopping(monitor=self.monitor, mode=self.mode), - ] - def configure_losses(self) -> None: """Initialize the loss criterion.""" diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 5d443b1c791..21a94b4be3d 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -190,6 +190,7 @@ def validation_step( if ( batch_idx < 10 and hasattr(self.trainer, "datamodule") + and hasattr(self.trainer.datamodule, "plot") and self.logger and hasattr(self.logger, "experiment") and hasattr(self.logger.experiment, "add_figure") @@ -313,6 +314,7 @@ def validation_step( if ( batch_idx < 10 and hasattr(self.trainer, "datamodule") + and hasattr(self.trainer.datamodule, "plot") and self.logger and hasattr(self.logger, "experiment") and hasattr(self.logger.experiment, "add_figure") diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index 840f2b6c36f..127c7bfc798 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -259,6 +259,7 @@ def validation_step( if ( batch_idx < 10 and hasattr(self.trainer, "datamodule") + and hasattr(self.trainer.datamodule, "plot") and self.logger and hasattr(self.logger, "experiment") and hasattr(self.logger.experiment, "add_figure") diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 2d142dacb19..b2847556620 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -185,6 +185,7 @@ def validation_step( if ( batch_idx < 10 and hasattr(self.trainer, "datamodule") + and hasattr(self.trainer.datamodule, "plot") and self.logger and hasattr(self.logger, "experiment") and 
hasattr(self.logger.experiment, "add_figure") diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index a0e03082545..9a67d051c9a 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -245,6 +245,7 @@ def validation_step( if ( batch_idx < 10 and hasattr(self.trainer, "datamodule") + and hasattr(self.trainer.datamodule, "plot") and self.logger and hasattr(self.logger, "experiment") and hasattr(self.logger.experiment, "add_figure") diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index ee332053d0d..d5c7cd97d39 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -42,7 +42,7 @@ def __init__( keys: list[str] = [] for key in data_keys: - if key == "image": + if key.startswith("image"): keys.append("input") elif key == "boxes": keys.append("bbox") From ac40ec18e03181a0858bdab61856dbec35c39a43 Mon Sep 17 00:00:00 2001 From: cookie-kyu Date: Tue, 14 Nov 2023 00:50:47 -0600 Subject: [PATCH 04/72] Updated data.py --- tests/data/.DS_Store | Bin 0 -> 6148 bytes .../SouthAmericaSoybean.zip | Bin 0 -> 1781 bytes .../South_America_Soybean_2002.tif | Bin 0 -> 806 bytes .../South_America_Soybean_2021.tif | Bin 0 -> 805 bytes tests/data/south_america_soybean/data.py | 28 ++++++++++++++++-- 5 files changed, 25 insertions(+), 3 deletions(-) create mode 100644 tests/data/.DS_Store create mode 100644 tests/data/south_america_soybean/SouthAmericaSoybean.zip create mode 100644 tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2002.tif create mode 100644 tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2021.tif diff --git a/tests/data/.DS_Store b/tests/data/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..e0546734601442efa0420264f6f27155793be465 GIT binary patch literal 6148 zcmeHKyG{c^3>-s>NHl4XDEAlmgHsf~ARiFILx2Q{6bKUCb^JEt2Z+-JN(wZV?Ai6& zUhNd;GXR_KZ_a@kfGO1x?{3E8=dL3;DvuGVtZ|PO)_B4#cBAOS4$nQsGq%{{Wsv`X zYt~xD*&D7{bB)dGa<_ThZtn8N!zJYv17bi7hygJm27Y0{doN{rk|-+%#DEz1V!*!- zjq2DH_KET7potNHIH5a)>zE~o%^k$9uur6iVwpFA@QDvT4HJ`-t*x6yn5Adj zUw-D~#)WeyHqLF7_F;}a&n)ym{yfXlnFmeGnYm|N3rTy-xH35*f&Fq1k4(a>nT^6e z#)burnPO~t24}enW1818&#-&IwovlS!jdD5CI3sGO!Ju>8^kiHJVH$9;fooQr$mLU zc`#@G^qF&JEPfCb!KEVG9VDE6=x(#9j!%|^RDrP_g4OIZV=_mxuetT$0e2wtdQ*x9-cM&$Fi2e?IndwXgNU&>+d*FN&T$TTps^@$_GBx5P_tax6J!ws>-z z_9T1O?;mum z9=!kUrB5r@l}1nbANgym&i#{*VyE}ypPpGDx!ywltMZlVtSdkNwToZ2`*+TGz5VVl zXZcU>TECf9`}(ihu>uSXs44TQf#7OSV9Hb_BV`&HQjju*M{dfLkoaQ1p}^ozzRi>m ziEUi+=KjAQw(!d8CB5>A?c-KkcOxlX_gb}0;h`?m%r!N^Ho|%#GMbx2=NBD0n2_!j zvg>5Zd37f3$gWSm>Fg^+@9;e3zLNR9YRk*#6H7aly6?)fdY-&vWIn$nXlu$d`Mu|6 z?3_NuEYkO7w8XwIYffF6U7n^pZO-%6niadh-ts*-Z^9C%b=u3nap&$@k-u(l9X;`0?nwMy;&KlxjJe^|~Ige$vciFY0sqODP|EH}hHM>#td}sEr!0NL<{!NqbuQD|)zqv1OxwK*H z!J?%9yBC+$+}AMp{^H4-w%$B0E?~lBWD;S-qGR53%@tUxv!lnv4&$jAcLE5pRVAc`a| z24yP%)rdpYZ~-;QAgQs1vVmF|f|1yK%{&auK(Sju+|8GXJlY!;0Dt`*YkkX z{jX<&(#&9*n}G#PYbvwxfYdWEu!Ctv1`aUI%)kkznHadhGz$Yam?N9Xdz=K%Gd0aET~c9rpeQMG{wN39XXK*hiN$ zHwUe|9=+6Wr~TTH4Wcu)8J#p-k^CvRhiBRG`W4CNUV&NlPT-1Jm+ z=AKr)WVuWCRZC~Oxpl5PoBC~+mQQu9+oJr>(ywLtP8RRFWPDrBgv)rv`KTYE-czH_ e_dmRt_VIU~dHva?Z+=+TeVr{>xz^w8f&c*7fWt8W literal 0 HcmV?d00001 diff --git a/tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2021.tif b/tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2021.tif new file mode 100644 index 0000000000000000000000000000000000000000..85bef390dfee0750e9f6f9dd9a035ab957f1c8ec GIT binary patch literal 
805 zcmebD)MDUZU|-qGR53%@tUxv!lnv4&$jAcLE5pRVAc`a| z24yP%)rdpYZ~`^RAgQs1vVmF|f|1yK%{&auK(Sju+|8GXJlY!;0Dt`*YkkX z{jX<&(#&9*n}G#PYbvwxfYdWEu!Ctv1`aUI%)kkznHadhGz$YamaWL4K5|wT;+4G6ON|vI0)7b%P1vf+a0(IKG9&P%0WL=`jV)sWu zQZl#9%_p5yEt&a8uQO)uC8bSPQ~DN$FFmU=_2!*fizVmzY6{+Ny>?TuCnTwm>&mJ9 zLYpq?TWGCax3{7{=&iXH?{;DNCHL1`NQFwB&p&zp-M21YeO=89F$P~&PQ5JMbnxQq zt0&JjJ(u5^n7)(qo!joT8^P+5opo{9p?lK5nQNU5w%mE{nts$?Grqs&$D>|N&0V&y d+v!G3VX&Oo|L&+A-!}@(JvuqNVWXRZ0suBhz@`8I literal 0 HcmV?d00001 diff --git a/tests/data/south_america_soybean/data.py b/tests/data/south_america_soybean/data.py index ef7fc2ee6b0..9e22df5541b 100644 --- a/tests/data/south_america_soybean/data.py +++ b/tests/data/south_america_soybean/data.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. - +#os.environ['PROJ_LIB'] = r'E:\Programs\anaconda3\envs\gis\Library\share\proj' import hashlib import os import shutil @@ -13,6 +13,28 @@ from rasterio.transform import Affine SIZE = 32 +wkt = """ +PROJCS["Albers Conical Equal Area", + GEOGCS["WGS 84", + DATUM["WGS_1984", + SPHEROID["WGS 84",6378137,298.257223563, + AUTHORITY["EPSG","7030"]], + AUTHORITY["EPSG","6326"]], + PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]], + UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]], + AUTHORITY["EPSG","4326"]], + PROJECTION["Albers_Conic_Equal_Area"], + PARAMETER["latitude_of_center",23], + PARAMETER["longitude_of_center",-96], + PARAMETER["standard_parallel_1",29.5], + PARAMETER["standard_parallel_2",45.5], + PARAMETER["false_easting",0], + PARAMETER["false_northing",0], + UNIT["meters",1], + AXIS["Easting",EAST], + AXIS["Northing",NORTH]] +""" + np.random.seed(0) files = ["South_America_Soybean_2002.tif", "South_America_Soybean_2021.tif"] @@ -23,7 +45,7 @@ def create_file(path: str, dtype: str): "driver": "GTiff", "dtype": dtype, "count": 1, - "crs": CRS.from_epsg(4326), + #"crs": CRS.from_wkt(wkt), "transform": Affine( 0.0002499999999999943131, 0.0, @@ -47,7 +69,7 @@ def create_file(path: str, dtype: str): if __name__ == "__main__": dir = os.path.join(os.getcwd(), "SouthAmericaSoybean") - + print(dir) if os.path.exists(dir) and os.path.isdir(dir): shutil.rmtree(dir) From fd0101179a0128672b5674115d2df6d1e42cc816 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:48:42 -0800 Subject: [PATCH 05/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- torchgeo/datasets/south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 19f10311261..edb664cbda6 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -12,7 +12,7 @@ from .utils import BoundingBox, download_url -class south_america_soybean(RasterDataset): +class SouthAmericaSoybean(RasterDataset): """South America Soybean Dataset Link: https://www.nature.com/articles/s41893-021-00729-z From 46db62b4c1fc4b65b43da474ced37edfa65ac3ab Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:48:54 -0800 Subject: [PATCH 06/72] Update tests/datasets/test_south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- tests/datasets/test_south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index 50be0b291f2..565703ecbc0 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -10,7 +10,7 @@ from rasterio.crs import CRS import torchgeo.datasets.utils -from torchgeo.datasets import south_america_soybean, BoundingBox, IntersectionDataset, UnionDataset +from torchgeo.datasets import SouthAmericaSoybean, BoundingBox, IntersectionDataset, UnionDataset def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: From f53d822bc0426d52b245504da9a72df16e0c88c5 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:49:01 -0800 Subject: [PATCH 07/72] Update tests/datasets/test_south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- tests/datasets/test_south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index 565703ecbc0..c16bffea35b 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -19,7 +19,7 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestSouthAmericaSoybean: @pytest.fixture - def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> south_america_soybean: + def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> SouthAmericaSoybean: monkeypatch.setattr(torchgeo.datasets.southamerica_soybean, "download_url", download_url) transforms = nn.Identity() md5s = { From 1ed0c28387a5bf3d6aa0c56e16e4cda33382f99d Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:49:12 -0800 Subject: [PATCH 08/72] Update tests/datasets/test_south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- tests/datasets/test_south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index c16bffea35b..2ddecaf5945 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -32,7 +32,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> SouthAmericaSoybe monkeypatch.setattr(south_america_soybean, "url", url) - return south_america_soybean( + return SouthAmericaSoybean( transforms=transforms, download=True, checksum=True, From 7a8e0cff5a619053be6f85cad5040c57d59ffbc1 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:49:20 -0800 Subject: [PATCH 09/72] Update tests/datasets/test_south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- tests/datasets/test_south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index 2ddecaf5945..7af67b771e1 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -74,7 +74,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None: def test_invalid_year(self, tmp_path: Path) -> None: with pytest.raises( AssertionError, - match="south_america_soybean data product only exists for the following years:", + 
match="SouthAmericaSoybean data product only exists for the following years:", ): south_america_soybean(str(tmp_path), years=[1996]) From 6b8b143d65449ee02094a03fb47abb9768b3ed2a Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:49:28 -0800 Subject: [PATCH 10/72] Update tests/datasets/test_south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- tests/datasets/test_south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index 7af67b771e1..c057e35e5e3 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -100,7 +100,7 @@ def test_plot_prediction(self, dataset: south_america_soybean) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found"): - south_america_soybean(str(tmp_path)) + SouthAmericaSoybean(str(tmp_path)) def test_invalid_query(self, dataset: south_america_soybean) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) From 38bc7cbce006b2d1d20bb506967e9212008088c9 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:49:41 -0800 Subject: [PATCH 11/72] Update tests/datasets/test_south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- tests/datasets/test_south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index c057e35e5e3..e30ed51217f 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -20,7 +20,7 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestSouthAmericaSoybean: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> SouthAmericaSoybean: - monkeypatch.setattr(torchgeo.datasets.southamerica_soybean, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.south_america_soybean, "download_url", download_url) transforms = nn.Identity() md5s = { 2002: "8a4a9dcea54b3ec7de07657b9f2c0893", From 95f53684d045e4f391440cc2de117b06ea0690af Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:49:50 -0800 Subject: [PATCH 12/72] Update tests/datasets/test_south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- tests/datasets/test_south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index e30ed51217f..c879f14f730 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -65,7 +65,7 @@ def test_already_extracted(self, dataset: south_america_soybean) -> None: south_america_soybean(dataset.paths, download=True, years=[2021]) def test_already_downloaded(self, tmp_path: Path) -> None: - pathname = os.path.join("tests", "data", "southamerica_soybean", "SouthAmerica_Soybean_2021.tif") + pathname = os.path.join("tests", "data", "south_america_soybean", "SouthAmerica_Soybean_2021.tif") root = str(tmp_path) shutil.copy(pathname, root) From b6c9efc6cd155a572b240ea18154ec8d19c6cdd5 Mon Sep 17 00:00:00 2001 
From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:50:05 -0800 Subject: [PATCH 13/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- torchgeo/datasets/south_america_soybean.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index edb664cbda6..4fc23edda66 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -171,7 +171,6 @@ def _verify(self) -> None: # Download the dataset self._download() - self._extract() def _download(self) -> None: """Download the dataset.""" for i in range(21): From 5bf910b7a7f2a62ef69ae121397e55850950f891 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:50:13 -0800 Subject: [PATCH 14/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- torchgeo/datasets/south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 4fc23edda66..79daec19829 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -212,7 +212,7 @@ def plot( axs[0, 0].set_title("Mask") if showing_predictions: - axs[0, 1].imshow(self.ordinal_cmap[pred], interpolation="none") + axs[0, 1].imshow(pred, interpolation="none") axs[0, 1].axis("off") if show_titles: axs[0, 1].set_title("Prediction") From 2c2cc42087b5cde44ae1f33406ec5c3191d4a1a2 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:50:25 -0800 Subject: [PATCH 15/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- torchgeo/datasets/south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 79daec19829..2293d0f239e 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -205,7 +205,7 @@ def plot( nrows=1, ncols=ncols, figsize=(ncols * 4, 4), squeeze=False ) - axs[0, 0].imshow(self.ordinal_cmap[mask], interpolation="none") + axs[0, 0].imshow(mask, interpolation="none") axs[0, 0].axis("off") if show_titles: From 3eb2e33855cf47944b817083ea2314cb452dcf64 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:50:35 -0800 Subject: [PATCH 16/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- torchgeo/datasets/south_america_soybean.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 2293d0f239e..84d26175c16 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -163,11 +163,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.paths}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the 
dataset." - ) + raise DatasetNotFoundError(self) # Download the dataset self._download() From 355d8f16b4b157ebbba1bceaac76923554b0f182 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:50:50 -0800 Subject: [PATCH 17/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- torchgeo/datasets/south_america_soybean.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 84d26175c16..d1bbad51f85 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -152,15 +152,6 @@ def _verify(self) -> None: assert isinstance(self.paths, str) - #todo - pathname = os.path.join(self.paths, "**", self.zipfile_glob) - if glob.glob(pathname, recursive=True): - exists = True - self._extract() - - if exists == True: - return - # Check if the user requested to download the dataset if not self.download: raise DatasetNotFoundError(self) From 430674ad43cb7338878d5686fc7d9d5eefa3ea7a Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:51:03 -0800 Subject: [PATCH 18/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- torchgeo/datasets/south_america_soybean.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index d1bbad51f85..580df3431cd 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -18,6 +18,7 @@ class SouthAmericaSoybean(RasterDataset): Link: https://www.nature.com/articles/s41893-021-00729-z Dataset contains 1 classes: + 0: nodata 1: soybean Dataset Format: From 5cc76e74de808740be58e06112178632cb1176a8 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:51:09 -0800 Subject: [PATCH 19/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- torchgeo/datasets/south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 580df3431cd..4735a4a231a 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -17,7 +17,7 @@ class SouthAmericaSoybean(RasterDataset): Link: https://www.nature.com/articles/s41893-021-00729-z - Dataset contains 1 classes: + Dataset contains 2 classes: 0: nodata 1: soybean From 6cc7e7cec48659a94c172f829551a81443a3e664 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:51:16 -0800 Subject: [PATCH 20/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- torchgeo/datasets/south_america_soybean.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 4735a4a231a..5e120a0abaa 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -25,7 +25,10 @@ class SouthAmericaSoybean(RasterDataset): 1) 21 .tif files If you use this dataset in 
your research, please use the corresponding citation: - Song, XP., Hansen, M.C., Potapov, P. et al. Massive soybean expansion in South America since 2000 and implications for conservation. Nat Sustain 4, 784–792 (2021). https://doi.org/10.1038/s41893-021-00729-z + +* https://doi.org/10.1038/s41893-021-00729-z + +.. versionadded:: 0.6 """ filename_glob = "SouthAmerica_Soybean_*.tif" From 4f41e13494043e8bd56c603aff48945c44dce52a Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:51:25 -0800 Subject: [PATCH 21/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- torchgeo/datasets/south_america_soybean.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 5e120a0abaa..8df64cc9e44 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -15,6 +15,8 @@ class SouthAmericaSoybean(RasterDataset): """South America Soybean Dataset +This dataset produced annual 30-m soybean maps of South America from 2001 to 2021. + Link: https://www.nature.com/articles/s41893-021-00729-z Dataset contains 2 classes: From 45b306e45c604eedaeae0d481a7a6177ebfc7055 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:51:31 -0800 Subject: [PATCH 22/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- torchgeo/datasets/south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 8df64cc9e44..6c288d4fd70 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -9,7 +9,7 @@ from rasterio.crs import CRS from .geo import RasterDataset -from .utils import BoundingBox, download_url +from .utils import BoundingBox, DatasetNotFoundError, download_url class SouthAmericaSoybean(RasterDataset): From 558cf44aca23b567bf85a1575df78b639b3f4a6c Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:51:39 -0800 Subject: [PATCH 23/72] Update tests/datasets/test_south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- tests/datasets/test_south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index c879f14f730..1693c9fc42c 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -102,7 +102,7 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found"): SouthAmericaSoybean(str(tmp_path)) - def test_invalid_query(self, dataset: south_america_soybean) -> None: + def test_invalid_query(self, dataset: SouthAmericaSoybean) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( IndexError, match="query: .* not found in index with bounds:" From 4c7a5aabbdef89d0e7fb00d8d7555f15a347ad61 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:51:59 -0800 Subject: [PATCH 24/72] Update tests/datasets/test_south_america_soybean.py Co-authored-by: Yi-Chia 
Chang <61452667+yichiac@users.noreply.github.com> --- tests/datasets/test_south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index 1693c9fc42c..dc21cafc06e 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -39,7 +39,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> SouthAmericaSoybe years=[2002, 2021], ) - def test_getitem(self, dataset: south_america_soybean) -> None: + def test_getitem(self, dataset: SouthAmericaSoybean) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) assert isinstance(x["crs"], CRS) From 48f1b63913e6544a709581bfbda63aa4cff21428 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:52:05 -0800 Subject: [PATCH 25/72] Update tests/datasets/test_south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- tests/datasets/test_south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index dc21cafc06e..9ecb27cab7e 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -80,7 +80,7 @@ def test_invalid_year(self, tmp_path: Path) -> None: def test_invalid_classes(self) -> None: with pytest.raises(AssertionError): - south_america_soybean(classes=[-1]) + SouthAmericaSoybean(classes=[-1]) with pytest.raises(AssertionError): south_america_soybean(classes=[11]) From 281fda2c3c97273c5448644b83e97426f8a4e683 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:52:11 -0800 Subject: [PATCH 26/72] Update tests/datasets/test_south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- tests/datasets/test_south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index 9ecb27cab7e..df7b2167f62 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -83,7 +83,7 @@ def test_invalid_classes(self) -> None: SouthAmericaSoybean(classes=[-1]) with pytest.raises(AssertionError): - south_america_soybean(classes=[11]) + SouthAmericaSoybean(classes=[11]) def test_plot(self, dataset: south_america_soybean) -> None: query = dataset.bounds From 12d1ea3399d2e7bf620fa69ba63bb4b1518fa274 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:52:47 -0800 Subject: [PATCH 27/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- torchgeo/datasets/south_america_soybean.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 6c288d4fd70..c96f2716fa7 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -108,11 +108,6 @@ def __init__( "South America Soybean data only exists for the following years: " f"{list(self.md5s.keys())}." ) - assert ( - set(classes) <= self.cmap.keys() - ), f"Only the following classes are valid: {list(self.cmap.keys())}." 
- assert 0 in classes, "Classes must include the background class: 0" - self.years = years self.paths = paths From bed3b0310788fa895bb8e821f8cdef1037375de3 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:53:03 -0800 Subject: [PATCH 28/72] Update tests/datasets/test_south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- tests/datasets/test_south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index df7b2167f62..56f9faf35b1 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -46,7 +46,7 @@ def test_getitem(self, dataset: SouthAmericaSoybean) -> None: assert isinstance(x["mask"], torch.Tensor) def test_classes(self) -> None: - root = os.path.join("tests", "data", "southamerica_soybean") + root = os.path.join("tests", "data", "south_america_soybean") classes = list(south_america_soybean.cmap.keys())[0:2] ds = south_america_soybean(root, years=[2021], classes=classes) sample = ds[ds.bounds] From c66359b2760a4d4f97b7dd3c826a68880acc0c9c Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:53:08 -0800 Subject: [PATCH 29/72] Update tests/datasets/test_south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- tests/datasets/test_south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index 56f9faf35b1..398cf734ac1 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -91,7 +91,7 @@ def test_plot(self, dataset: south_america_soybean) -> None: dataset.plot(x, suptitle="Test") plt.close() - def test_plot_prediction(self, dataset: south_america_soybean) -> None: + def test_plot_prediction(self, dataset: SouthAmericaSoybean) -> None: query = dataset.bounds x = dataset[query] x["prediction"] = x["mask"].clone() From 3886f0b0b4c02f2488f624a5f50e1f7d7f9e99d1 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:53:20 -0800 Subject: [PATCH 30/72] Update tests/datasets/test_south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- tests/datasets/test_south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index 398cf734ac1..156745ccbb4 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -85,7 +85,7 @@ def test_invalid_classes(self) -> None: with pytest.raises(AssertionError): SouthAmericaSoybean(classes=[11]) - def test_plot(self, dataset: south_america_soybean) -> None: + def test_plot(self, dataset: SouthAmericaSoybean) -> None: query = dataset.bounds x = dataset[query] dataset.plot(x, suptitle="Test") From 824ce88153d2314396826f3ab67fea030c3eafc6 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:53:43 -0800 Subject: [PATCH 31/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- 
torchgeo/datasets/south_america_soybean.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py
index c96f2716fa7..b51a4afaa93 100644
--- a/torchgeo/datasets/south_america_soybean.py
+++ b/torchgeo/datasets/south_america_soybean.py
@@ -24,7 +24,9 @@ class SouthAmericaSoybean(RasterDataset):
     1: soybean

     Dataset Format:
-    1) 21 .tif files
+
+    * 21 .tif files
+

     If you use this dataset in your research, please use the corresponding citation:

From 557b16975e40ad1113ba627352a2672c9741c31c Mon Sep 17 00:00:00 2001
From: Jingtong <115182031+cookie-kyu@users.noreply.github.com>
Date: Fri, 17 Nov 2023 19:53:58 -0800
Subject: [PATCH 32/72] Update torchgeo/datasets/south_america_soybean.py

Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com>
---
 torchgeo/datasets/south_america_soybean.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py
index b51a4afaa93..1a1d153c85e 100644
--- a/torchgeo/datasets/south_america_soybean.py
+++ b/torchgeo/datasets/south_america_soybean.py
@@ -38,7 +38,6 @@ class SouthAmericaSoybean(RasterDataset):
     filename_glob = "SouthAmerica_Soybean_*.tif"
     filename_regex = (r"SouthAmerica_Soybean_(?P<year>\d{4})\.tif")
-    zipfile_glob = ""

     date_format = "%Y"
     is_image = False

From 55c1720c8d2cf90f4f8f409dd7c1e178147fdc4e Mon Sep 17 00:00:00 2001
From: Jingtong <115182031+cookie-kyu@users.noreply.github.com>
Date: Fri, 17 Nov 2023 19:54:12 -0800
Subject: [PATCH 33/72] Update torchgeo/datasets/south_america_soybean.py

Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com>
---
 torchgeo/datasets/south_america_soybean.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py
index 1a1d153c85e..59f8d1a0f49 100644
--- a/torchgeo/datasets/south_america_soybean.py
+++ b/torchgeo/datasets/south_america_soybean.py
@@ -42,7 +42,7 @@ class SouthAmericaSoybean(RasterDataset):
     date_format = "%Y"
     is_image = False

-    url = "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2001.tif"
+    url = "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_"

     md5s = {
         2001: "2914b0af7590a0ca4dfa9ccefc99020f",

From 674997468562451558b54bf6f2e30324f3b6f983 Mon Sep 17 00:00:00 2001
From: Jingtong <115182031+cookie-kyu@users.noreply.github.com>
Date: Fri, 17 Nov 2023 19:54:48 -0800
Subject: [PATCH 34/72] Update torchgeo/datasets/south_america_soybean.py

Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com>
---
 torchgeo/datasets/south_america_soybean.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py
index 59f8d1a0f49..e677be18c11 100644
--- a/torchgeo/datasets/south_america_soybean.py
+++ b/torchgeo/datasets/south_america_soybean.py
@@ -69,11 +69,6 @@ class SouthAmericaSoybean(RasterDataset):
     }

-
-    cmap = {
-        0: (0,0,0,0),
-        1: (255,0,255,255)
-    }

     def __init__(
         self,

From d1e7a019db2f30e90b86dbd9d21bd492f8076bec Mon Sep 17 00:00:00 2001
From: Jingtong <115182031+cookie-kyu@users.noreply.github.com>
Date: Fri, 17 Nov 2023 19:55:13 -0800
Subject: [PATCH 35/72] Update torchgeo/datasets/south_america_soybean.py

Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com>
---
 torchgeo/datasets/south_america_soybean.py | 1 -
 1 file changed, 1
deletion(-) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index e677be18c11..02481144e6f 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -76,7 +76,6 @@ def __init__( crs: Optional[CRS] = None, res: Optional[float] = None, years: list[int] = [2021], - classes: list[int] = list(cmap.keys()), transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, cache: bool = True, download: bool = False, From e7886241c9bd97c8ebf6db7d057c8a839dbb3b11 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:55:37 -0800 Subject: [PATCH 36/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- torchgeo/datasets/south_america_soybean.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 02481144e6f..248b812eeda 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -106,7 +106,6 @@ def __init__( self.years = years self.paths = paths - self.classes = classes self.download = download self.checksum = checksum self.ordinal_map = torch.zeros(max(self.cmap.keys()) + 1, dtype=self.dtype) From 15809645833d8df702b5755083fd6dc7122e491a Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:55:54 -0800 Subject: [PATCH 37/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- torchgeo/datasets/south_america_soybean.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 248b812eeda..74ff4cdb152 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -108,8 +108,6 @@ def __init__( self.paths = paths self.download = download self.checksum = checksum - self.ordinal_map = torch.zeros(max(self.cmap.keys()) + 1, dtype=self.dtype) - self.ordinal_cmap = torch.zeros((len(self.classes), 4), dtype=torch.uint8) self._verify() From a5262cef32c5b892cc21326ababf325be73f3536 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:56:11 -0800 Subject: [PATCH 38/72] Update tests/datasets/test_south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- tests/datasets/test_south_america_soybean.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index 156745ccbb4..8404da2d98a 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -47,8 +47,8 @@ def test_getitem(self, dataset: SouthAmericaSoybean) -> None: def test_classes(self) -> None: root = os.path.join("tests", "data", "south_america_soybean") - classes = list(south_america_soybean.cmap.keys())[0:2] - ds = south_america_soybean(root, years=[2021], classes=classes) + classes = list(SouthAmericaSoybean.cmap.keys())[0:2] + ds = SouthAmericaSoybean(root, years=[2021], classes=classes) sample = ds[ds.bounds] mask = sample["mask"] assert mask.max() < len(classes) From d8ca8d67a30964402d0f83ced20d25650d5ae40e Mon Sep 17 00:00:00 2001 From: Jingtong 
<115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:56:38 -0800 Subject: [PATCH 39/72] Update tests/datasets/test_south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- tests/datasets/test_south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index 8404da2d98a..2fc095abc0d 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -53,7 +53,7 @@ def test_classes(self) -> None: mask = sample["mask"] assert mask.max() < len(classes) - def test_and(self, dataset: south_america_soybean) -> None: + def test_and(self, dataset: SouthAmericaSoybean) -> None: ds = dataset & dataset assert isinstance(ds, IntersectionDataset) From 0a256a3194a5712602b307684f708c2eaf12c20f Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:58:47 -0800 Subject: [PATCH 40/72] Update tests/datasets/test_south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- tests/datasets/test_south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index 2fc095abc0d..04dd8ff7d3d 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -57,7 +57,7 @@ def test_and(self, dataset: SouthAmericaSoybean) -> None: ds = dataset & dataset assert isinstance(ds, IntersectionDataset) - def test_or(self, dataset: south_america_soybean) -> None: + def test_or(self, dataset: SouthAmericaSoybean) -> None: ds = dataset | dataset assert isinstance(ds, UnionDataset) From c0070edd21dd8ed6d3dfb5f7fb3e7c1008c3073e Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:59:05 -0800 Subject: [PATCH 41/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- torchgeo/datasets/south_america_soybean.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 74ff4cdb152..bdda72f8db1 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -112,11 +112,6 @@ def __init__( self._verify() super().__init__(paths, crs, res, transforms=transforms, cache=cache) - - for v, k in enumerate(self.classes): - self.ordinal_map[k] = v - self.ordinal_cmap[v] = torch.tensor(self.cmap[k]) - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve mask and metadata indexed by query. 
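The ordinal-remapping machinery deleted across PATCH 34 through PATCH 42 follows a pattern shared by TorchGeo's multi-class rasters: a lookup tensor `ordinal_map`, indexed by raw class value and holding consecutive ordinal labels, plus a matching `ordinal_cmap` of RGBA rows, so that `__getitem__` can renumber an entire mask with a single tensor-indexing operation. For a two-class soybean product the mapping is the identity, which is why these patches can drop it. A minimal, self-contained sketch of the pattern as it appeared here (the `torch.long` dtype and the standalone variable names are illustrative assumptions, not the dataset's final API):

    import torch

    # Colormap removed in PATCH 34: raw class value -> RGBA (assumption:
    # reused here outside the class for a standalone example).
    cmap = {0: (0, 0, 0, 0), 1: (255, 0, 255, 255)}
    classes = [0, 1]  # raw class values to keep, background first

    # Lookup tables: index by raw class value to get the ordinal label,
    # index by ordinal label to get the RGBA row for plotting.
    ordinal_map = torch.zeros(max(cmap) + 1, dtype=torch.long)
    ordinal_cmap = torch.zeros((len(classes), 4), dtype=torch.uint8)
    for ordinal, raw in enumerate(classes):
        ordinal_map[raw] = ordinal
        ordinal_cmap[ordinal] = torch.tensor(cmap[raw])

    # One vectorized indexing op remaps a whole mask, as __getitem__
    # did before PATCH 42 removed the call.
    raw_mask = torch.tensor([[0, 1], [1, 0]])
    mask = ordinal_map[raw_mask]   # ordinal labels for training
    rgba = ordinal_cmap[mask]      # (H, W, 4) uint8 image for plotting

With only soybean and background the remap changes nothing, but the same pattern matters for products whose raw class codes are sparse (e.g. CDL's 1, 5, 176), where models expect dense labels 0..N-1.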
From 1fe3a16c6d3553d51ce478fc73108baa8a50e44d Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:59:22 -0800 Subject: [PATCH 42/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- torchgeo/datasets/south_america_soybean.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index bdda72f8db1..7f4f6d25d55 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -123,7 +123,6 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: IndexError: if query is not found in the index """ sample = super().__getitem__(query) - sample["mask"] = self.ordinal_map[sample["mask"]] return sample def _verify(self) -> None: From 8a6d0b05c25a3902f69b9a64f207859ed2eb93f0 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:59:34 -0800 Subject: [PATCH 43/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- torchgeo/datasets/south_america_soybean.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 7f4f6d25d55..c1b2e18a749 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -133,10 +133,6 @@ def _verify(self) -> None: # Check if the extracted files already exist if self.files: return - - # Check if the zip files have already been downloaded - exists = False - assert isinstance(self.paths, str) # Check if the user requested to download the dataset From 74356ee7290bee69125b3a5309ba0958d2815a91 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 20:00:20 -0800 Subject: [PATCH 44/72] Update tests/datasets/test_south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- tests/datasets/test_south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index 04dd8ff7d3d..6a40d1f9b73 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -69,7 +69,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None: root = str(tmp_path) shutil.copy(pathname, root) - south_america_soybean(root, years=[2021]) + SouthAmericaSoybean(root, years=[2021]) def test_invalid_year(self, tmp_path: Path) -> None: with pytest.raises( From 6b59cfdded8f89cac90b0cb3c4d5c4d8e884e051 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 20:00:31 -0800 Subject: [PATCH 45/72] Update tests/datasets/test_south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- tests/datasets/test_south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index 6a40d1f9b73..0a5593364f2 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -61,7 +61,7 @@ def test_or(self, dataset: SouthAmericaSoybean) -> None: ds = dataset | dataset assert 
isinstance(ds, UnionDataset) - def test_already_extracted(self, dataset: south_america_soybean) -> None: + def test_already_extracted(self, dataset: SouthAmericaSoybean) -> None: south_america_soybean(dataset.paths, download=True, years=[2021]) def test_already_downloaded(self, tmp_path: Path) -> None: From 61579dd9f025b26755f3a618e853a8c5959796b1 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 17 Nov 2023 20:00:47 -0800 Subject: [PATCH 46/72] Update tests/datasets/test_south_america_soybean.py Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- tests/datasets/test_south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index 0a5593364f2..0e3edcc5afb 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -62,7 +62,7 @@ def test_or(self, dataset: SouthAmericaSoybean) -> None: assert isinstance(ds, UnionDataset) def test_already_extracted(self, dataset: SouthAmericaSoybean) -> None: - south_america_soybean(dataset.paths, download=True, years=[2021]) + SouthAmericaSoybean(dataset.paths, download=True, years=[2021]) def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join("tests", "data", "south_america_soybean", "SouthAmerica_Soybean_2021.tif") From e67613fb1c8c3041dae93712e4ffb03093eb15f3 Mon Sep 17 00:00:00 2001 From: cookie-kyu Date: Fri, 1 Dec 2023 10:53:50 -0600 Subject: [PATCH 47/72] Updated tests --- tests/datasets/test_south_america_soybean.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index 0e3edcc5afb..da6e807d783 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -26,10 +26,10 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> SouthAmericaSoybe 2002: "8a4a9dcea54b3ec7de07657b9f2c0893", 2021: "edff3ada13a1a9910d1fe844d28ae4f", } - monkeypatch.setattr(south_america_soybean, "md5s", md5s) + monkeypatch.setattr(SouthAmericaSoybean, "md5s", md5s) url = os.path.join("tests", "data", "south_america_soybean", "SouthAmerica_Soybean_{}.tif") - monkeypatch.setattr(south_america_soybean, "url", url) + monkeypatch.setattr(SouthAmericaSoybean, "url", url) return SouthAmericaSoybean( @@ -76,7 +76,7 @@ def test_invalid_year(self, tmp_path: Path) -> None: AssertionError, match="SouthAmericaSoybean data product only exists for the following years:", ): - south_america_soybean(str(tmp_path), years=[1996]) + SouthAmericaSoybean(str(tmp_path), years=[1996]) def test_invalid_classes(self) -> None: with pytest.raises(AssertionError): From 97ba0fce050eaae9514abacaade6bf83bbe1f7f3 Mon Sep 17 00:00:00 2001 From: cookie-kyu Date: Wed, 3 Jan 2024 21:45:15 -0800 Subject: [PATCH 48/72] fixed an error in init --- tests/data/.DS_Store | Bin 6148 -> 6148 bytes .../South_America_Soybean_2002.tif | Bin 806 -> 0 bytes .../South_America_Soybean_2021.tif | Bin 805 -> 0 bytes tests/data/south_america_soybean/data.py | 2 +- torchgeo/datasets/__init__.py | 2 +- 5 files changed, 2 insertions(+), 2 deletions(-) delete mode 100644 tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2002.tif delete mode 100644 tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2021.tif diff --git a/tests/data/.DS_Store 
b/tests/data/.DS_Store index e0546734601442efa0420264f6f27155793be465..179e7aa16ad09e16257778afe4e84406ec02afc8 100644 GIT binary patch delta 272 zcmZoMXfc=|#>B)qu~2NHo+2aD!~pA!4;mOJ8;Gz>?ANYODlaZb%E?b+U|`spRFIQd zTw-8wjgg6&g_Vt+gPnt$BQ`iAzdX1kv81%vDX}OT#0$yK&q;!@6O+O+Q_JH8M4a>U zN)j{kQj5SEGE-84N@Bt@^HTE5o$^cbQi{QPgCPCJ*u~2NHo+2aT!~knX#?AkkI9N9aFc&jzX6NAN07`FmWd6=PnP0?^ Pkzuk8kM!mkkrm7U^-qGR53%@tUxv!lnv4&$jAcLE5pRVAc`a| z24yP%)rdpYZ~-;QAgQs1vVmF|f|1yK%{&auK(Sju+|8GXJlY!;0Dt`*YkkX z{jX<&(#&9*n}G#PYbvwxfYdWEu!Ctv1`aUI%)kkznHadhGz$Yam?N9Xdz=K%Gd0aET~c9rpeQMG{wN39XXK*hiN$ zHwUe|9=+6Wr~TTH4Wcu)8J#p-k^CvRhiBRG`W4CNUV&NlPT-1Jm+ z=AKr)WVuWCRZC~Oxpl5PoBC~+mQQu9+oJr>(ywLtP8RRFWPDrBgv)rv`KTYE-czH_ e_dmRt_VIU~dHva?Z+=+TeVr{>xz^w8f&c*7fWt8W diff --git a/tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2021.tif b/tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2021.tif deleted file mode 100644 index 85bef390dfee0750e9f6f9dd9a035ab957f1c8ec..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 805 zcmebD)MDUZU|-qGR53%@tUxv!lnv4&$jAcLE5pRVAc`a| z24yP%)rdpYZ~`^RAgQs1vVmF|f|1yK%{&auK(Sju+|8GXJlY!;0Dt`*YkkX z{jX<&(#&9*n}G#PYbvwxfYdWEu!Ctv1`aUI%)kkznHadhGz$YamaWL4K5|wT;+4G6ON|vI0)7b%P1vf+a0(IKG9&P%0WL=`jV)sWu zQZl#9%_p5yEt&a8uQO)uC8bSPQ~DN$FFmU=_2!*fizVmzY6{+Ny>?TuCnTwm>&mJ9 zLYpq?TWGCax3{7{=&iXH?{;DNCHL1`NQFwB&p&zp-M21YeO=89F$P~&PQ5JMbnxQq zt0&JjJ(u5^n7)(qo!joT8^P+5opo{9p?lK5nQNU5w%mE{nts$?Grqs&$D>|N&0V&y d+v!G3VX&Oo|L&+A-!}@(JvuqNVWXRZ0suBhz@`8I diff --git a/tests/data/south_america_soybean/data.py b/tests/data/south_america_soybean/data.py index 9e22df5541b..8c7c3afe2c7 100644 --- a/tests/data/south_america_soybean/data.py +++ b/tests/data/south_america_soybean/data.py @@ -45,7 +45,7 @@ def create_file(path: str, dtype: str): "driver": "GTiff", "dtype": dtype, "count": 1, - #"crs": CRS.from_wkt(wkt), + "crs": CRS.from_epsg(4326), "transform": Affine( 0.0002499999999999943131, 0.0, diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index f6cddcc80a5..370558aa239 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -176,7 +176,7 @@ "Sentinel", "Sentinel1", "Sentinel2", - "SouthAmericaSoybean" + "SouthAmericaSoybean", # NonGeoDataset "ADVANCE", "BeninSmallHolderCashews", From 305ea4421058499b806a13db548047ae007b4efe Mon Sep 17 00:00:00 2001 From: cookie-kyu Date: Sun, 14 Jan 2024 21:10:03 -0800 Subject: [PATCH 49/72] fixed some path inconsistencies --- .../SouthAmericaSoybean.zip | Bin 1781 -> 1578 bytes .../South_America_Soybean_2002.tif | Bin 0 -> 644 bytes .../South_America_Soybean_2021.tif | Bin 0 -> 643 bytes tests/datasets/test_south_america_soybean.py | 52 ++++++------------ torchgeo/datasets/south_america_soybean.py | 36 +++++++++--- 5 files changed, 43 insertions(+), 45 deletions(-) create mode 100644 tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2002.tif create mode 100644 tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2021.tif diff --git a/tests/data/south_america_soybean/SouthAmericaSoybean.zip b/tests/data/south_america_soybean/SouthAmericaSoybean.zip index fb060dad71566dfc2ab3f24fe6a41c85f0723e92..5453b89fc25128d3e30c0281ee069bf1e8773aa7 100644 GIT binary patch literal 1578 zcmWIWW@Zs#0D)zjR3g9(C=mdpLyJ?3iuJR~^+WQDk~7>>^Yu$WjC2qqwYa2MKP9mw zQNK99v?L=wF*mg+GdVH7IKMI}H8D>=7$oWl5e$Zk1wd`%0NWO~$II^tBLhPV69a=0 z&^By3z+&-GW8$Ia#2XnH80nQ{rhWc$CFwyzLV`d_LPAPf0?Q_mh8cpLXIWC4Lkyj7 
zu5mcSmfE~QL798TIwqk-9*>edo^?21Enwu7R59HaU?wcUuDtQ*tV)R)jbC&X^0FQ# zCNN)`-%wF&ZG6KAk|JsrLsD^uGb-R;1}aPY}gdHtVik~a=#urV$8v~-rD zyYq=g`-HE21z&p-#5WonY-qUQ!_%V^V8|!%*Qa4(ac;MLCOD~pexaf4u zZ_4_cnvnsEkN8er{xhSi@%U^T$Ekehr*w*HCSKbUJCEOf@0ZM?(ANc31rMK3&2IZ7 ztha*m&7;bt`IZ;mzcxOR(=UHm*&fujt7seRFI_$_-k6dZrjC%`}sdW{eEUPcuw|GR-tLI{s(oWe4N{mbCY{u=<*a$;VEaA2ojHbabLp^hZs18J@CbE=y?cc@MJu|C4 zGbdl2vaj`__xV*mdn?ngC7rC?H>+uT^!mFZZ#X@*p6n{h+J9kL_b;)i?CWpWzTmIQ zUUzhNY((A3zhA|I76v{4?0>KR``YgC`0&q8JqM>x4c7NIm)-N@*Q=AozUQBot1Wx~ zq~`fH>$h$jcBQ1p#+1F7_a*;XRC=>+`MllB;(ot#c{g4xTfhKJGmK0k%(%-nND(K%@YWH;qJAj{vmJXm2eF-jVM${i rkV$%}hwvT5eDpklFniev$BEAUFA@QDvT4HJ`-t*x6yn5Adj zUw-D~#)WeyHqLF7_F;}a&n)ym{yfXlnFmeGnYm|N3rTy-xH35*f&Fq1k4(a>nT^6e z#)burnPO~t24}enW1818&#-&IwovlS!jdD5CI3sGO!Ju>8^kiHJVH$9;fooQr$mLU zc`#@G^qF&JEPfCb!KEVG9VDE6=x(#9j!%|^RDrP_g4OIZV=_mxuetT$0e2wtdQ*x9-cM&$Fi2e?IndwXgNU&>+d*FN&T$TTps^@$_GBx5P_tax6J!ws>-z z_9T1O?;mum z9=!kUrB5r@l}1nbANgym&i#{*VyE}ypPpGDx!ywltMZlVtSdkNwToZ2`*+TGz5VVl zXZcU>TECf9`}(ihu>uSXs44TQf#7OSV9Hb_BV`&HQjju*M{dfLkoaQ1p}^ozzRi>m ziEUi+=KjAQw(!d8CB5>A?c-KkcOxlX_gb}0;h`?m%r!N^Ho|%#GMbx2=NBD0n2_!j zvg>5Zd37f3$gWSm>Fg^+@9;e3zLNR9YRk*#6H7aly6?)fdY-&vWIn$nXlu$d`Mu|6 z?3_NuEYkO7w8XwIYffF6U7n^pZO-%6niadh-ts*-Z^9C%b=u3nap&$@k-u(l9X;`0?nwMy;&KlxjJe^|~Ige$vciFY0sqODP|EH}hHM>#td}sEr!0NL<{!NqbuQD|)zqv1OxwK*H z!J?%9yBC+$+}AMp{^H4-w%$B0E?~lBWD;S-qGR53%@tUxv!lnv4&$jAcLTLt8ZB8iJZ z*$O~4;!rhQKs7Q*YHXowCZL*NBsO0&4+Aq$>=qC=weT>o0qIXbyuO_o>_#P^IUCw} z7(n`xfb5O!OkkfB0ofohdF97TzDWl5Fd8C?o2-1PX2{VHa^MeGpMyP2B_r5p>RyeS6voty>Y_yQE6Xk1mN*2^8ans~!KAK|S z@g`4DrNc92Vns@rHE);a!%r`Lnq1X7oE}{GdF7g+{Dcc?0=sx#>03_pQGFn}$ID%K z-l3GDuAFUwev5-r16OV>QZ`zD@_E6HEAy86Pr9)B+seCny3&i!YhM1%8>g=}?SxiI zW$dHNnVW;wU5{Ssx6^)Y$Oh3F+l)>cu1NkA+{3f%c>N0R7p6PEp8FYg;`CLfg-P}6YolX6=OUtLa)@@P#XX)3nd?$-{T{6BcXToK?;(XMP jQ17Wx=ldUCO#Apd&%FNZ(l%PtwtX%7Fc0m9D8>Xl~ literal 0 HcmV?d00001 diff --git a/tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2021.tif b/tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2021.tif new file mode 100644 index 0000000000000000000000000000000000000000..a220b500677c0b2b954c9a735358450535cfec60 GIT binary patch literal 643 zcmebD)MDUZU|-qGR53%@tUxv!lnv4&$jAcLTLt8ZB8iJZ z*$O~4;!ri5Ks7Q*YHXowCZL*NBsO0&4+Aq$>=qC=weT>o0qIXbyuO_o>_#P^IUCw} z7(n`xfb5O!OkkfB0ofohdF97TzDWl5Fd8C?o2-1PX2{VHa^MeGpMyP2B_r5p>MYiHojyM?XOo>W2ne6#QU?odYzUk}$wSt?We1SUcUXM2YJhCoP zWU>3BASsz!=H`=5s+P?Bqt_WT_ma{kt0{d8!+}h^QE)^ literal 0 HcmV?d00001 diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index da6e807d783..d6293d0bb92 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -22,21 +22,15 @@ class TestSouthAmericaSoybean: def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> SouthAmericaSoybean: monkeypatch.setattr(torchgeo.datasets.south_america_soybean, "download_url", download_url) transforms = nn.Identity() - md5s = { - 2002: "8a4a9dcea54b3ec7de07657b9f2c0893", - 2021: "edff3ada13a1a9910d1fe844d28ae4f", - } - monkeypatch.setattr(SouthAmericaSoybean, "md5s", md5s) - - url = os.path.join("tests", "data", "south_america_soybean", "SouthAmerica_Soybean_{}.tif") + url = os.path.join("tests", "data", "south_america_soybean", "SouthAmericaSoybean.zip") monkeypatch.setattr(SouthAmericaSoybean, "url", url) - - - return SouthAmericaSoybean( + root = str(tmp_path) + #url = os.path.join("tests", "data", "south_america_soybean", "SouthAmerica_Soybean_{}.tif") + return 
SouthAmericaSoybean(root, transforms=transforms, download=True, checksum=True, - years=[2002, 2021], + ) def test_getitem(self, dataset: SouthAmericaSoybean) -> None: @@ -45,14 +39,6 @@ def test_getitem(self, dataset: SouthAmericaSoybean) -> None: assert isinstance(x["crs"], CRS) assert isinstance(x["mask"], torch.Tensor) - def test_classes(self) -> None: - root = os.path.join("tests", "data", "south_america_soybean") - classes = list(SouthAmericaSoybean.cmap.keys())[0:2] - ds = SouthAmericaSoybean(root, years=[2021], classes=classes) - sample = ds[ds.bounds] - mask = sample["mask"] - assert mask.max() < len(classes) - def test_and(self, dataset: SouthAmericaSoybean) -> None: ds = dataset & dataset assert isinstance(ds, IntersectionDataset) @@ -62,28 +48,22 @@ def test_or(self, dataset: SouthAmericaSoybean) -> None: assert isinstance(ds, UnionDataset) def test_already_extracted(self, dataset: SouthAmericaSoybean) -> None: - SouthAmericaSoybean(dataset.paths, download=True, years=[2021]) + SouthAmericaSoybean(dataset.paths, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: - pathname = os.path.join("tests", "data", "south_america_soybean", "SouthAmerica_Soybean_2021.tif") + pathname = os.path.join("tests","data", "south_america_soybean", "SouthAmericaSoybean.zip") + root = str(tmp_path) shutil.copy(pathname, root) - SouthAmericaSoybean(root, years=[2021]) - - def test_invalid_year(self, tmp_path: Path) -> None: - with pytest.raises( - AssertionError, - match="SouthAmericaSoybean data product only exists for the following years:", - ): - SouthAmericaSoybean(str(tmp_path), years=[1996]) - - def test_invalid_classes(self) -> None: - with pytest.raises(AssertionError): - SouthAmericaSoybean(classes=[-1]) - - with pytest.raises(AssertionError): - SouthAmericaSoybean(classes=[11]) + SouthAmericaSoybean(root) + + # def test_invalid_year(self, tmp_path: Path) -> None: + # with pytest.raises( + # AssertionError, + # match="SouthAmericaSoybean data product only exists for the following years:", + # ): + # SouthAmericaSoybean(str(tmp_path), years=[1996]) def test_plot(self, dataset: SouthAmericaSoybean) -> None: query = dataset.bounds diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index c1b2e18a749..f21b3b70bde 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -9,7 +9,7 @@ from rasterio.crs import CRS from .geo import RasterDataset -from .utils import BoundingBox, DatasetNotFoundError, download_url +from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive class SouthAmericaSoybean(RasterDataset): @@ -35,8 +35,9 @@ class SouthAmericaSoybean(RasterDataset): .. 
versionadded:: 0.6
     """
-    filename_glob = "SouthAmerica_Soybean_*.tif"
-    filename_regex = (r"SouthAmerica_Soybean_(?P<date>\d{4})\.tif")
+    filename_glob = "SouthAmerica_Soybean_*.*"
+    filename_regex = (r"SouthAmerica_Soybean_(?P<date>\d{4})\)")
+    zipfile_glob = "SouthAmericaSoybean.zip"

     date_format = "%Y"
@@ -108,7 +109,7 @@ def __init__(
         self.paths = paths
         self.download = download
         self.checksum = checksum
-
+        print("paths:" , paths)
         self._verify()

         super().__init__(paths, crs, res, transforms=transforms, cache=cache)
@@ -135,19 +136,36 @@ def _verify(self) -> None:
             return

         assert isinstance(self.paths, str)
+        pathname = os.path.join(self.paths, "**", self.zipfile_glob)
+        if glob.glob(pathname, recursive=True):
+            self._extract()
+            return
+
+        # Check if the user requested to download the dataset
         if not self.download:
             raise DatasetNotFoundError(self)

         # Download the dataset
         self._download()
+        self._extract()

     def _download(self) -> None:
         """Download the dataset."""
-        for i in range(21):
-            ext = ".tif"
-            downloadUrl = self.url + str(i+2001) + ext
-            download_url(downloadUrl,self.paths,md5 = self.md5s if self.checksum else None)
-
+        # for i in range(21):
+        #     ext = ".tif"
+        #     downloadUrl = self.url + str(i+2001) + ext
+        #     download_url(downloadUrl,self.paths,md5 = self.md5s if self.checksum else None)
+
+        filename = "SouthAmericaSoybean.zip"
+        download_url(
+            self.url, self.paths, filename, md5s=self.md5s if self.checksum else None
+        )
+
+    def _extract(self) -> None:
+        """Extract the dataset."""
+        assert isinstance(self.paths, str)
+        pathname = os.path.join(self.paths, "**", self.zipfile_glob)
+        extract_archive(glob.glob(pathname, recursive=True)[0], self.paths)
+
     def plot(
         self,
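The order of checks that `_verify` performs after this patch (extracted files first, then a local archive, then download followed by extract) is the usual torchgeo layout for archived datasets. Below is a minimal, self-contained sketch of that same ordering; the helper name `ensure_extracted` and the hard-coded archive name are illustrative assumptions, while `download_url` and `extract_archive` are the `torchgeo.datasets.utils` helpers that the diff itself imports:

    import glob
    import os

    from torchgeo.datasets.utils import download_url, extract_archive


    def ensure_extracted(root: str, url: str) -> None:
        # 1. Extracted GeoTIFFs already on disk: nothing to do.
        if glob.glob(os.path.join(root, "**", "*.tif"), recursive=True):
            return

        # 2. Archive present but not yet extracted: extract it in place.
        archives = glob.glob(
            os.path.join(root, "**", "SouthAmericaSoybean.zip"), recursive=True
        )
        if archives:
            extract_archive(archives[0], root)
            return

        # 3. Nothing local: download the archive, then extract it.
        download_url(url, root, filename="SouthAmericaSoybean.zip")
        extract_archive(os.path.join(root, "SouthAmericaSoybean.zip"), root)

Checking in this order keeps repeated instantiations cheap: once the GeoTIFFs exist, neither the archive nor the network is touched again.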
os.path.join("tests","data", "south_america_soybean", "SouthAmericaSoybean.zip") - root = str(tmp_path) - shutil.copy(pathname, root) SouthAmericaSoybean(root) - - # def test_invalid_year(self, tmp_path: Path) -> None: - # with pytest.raises( - # AssertionError, - # match="SouthAmericaSoybean data product only exists for the following years:", - # ): - # SouthAmericaSoybean(str(tmp_path), years=[1996]) - + def test_plot(self, dataset: SouthAmericaSoybean) -> None: query = dataset.bounds x = dataset[query] @@ -79,7 +67,7 @@ def test_plot_prediction(self, dataset: SouthAmericaSoybean) -> None: plt.close() def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(RuntimeError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): SouthAmericaSoybean(str(tmp_path)) def test_invalid_query(self, dataset: SouthAmericaSoybean) -> None: diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index f21b3b70bde..188f9f82a8b 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -35,8 +35,8 @@ class SouthAmericaSoybean(RasterDataset): .. versionadded:: 0.6 """ - filename_glob = "SouthAmerica_Soybean_*.*" - filename_regex = (r"SouthAmerica_Soybean_(?P\d{4})\)") + filename_glob = "South_America_Soybean_*.*" + filename_regex = r"South_America_Soybean_(?P\d{4})" zipfile_glob = "SouthAmericaSoybean.zip" @@ -45,6 +45,7 @@ class SouthAmericaSoybean(RasterDataset): url = "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_" + md5 = "7f1d06a57cc6c4ae6be3b3fb9464ddeb" md5s = { 2001: "2914b0af7590a0ca4dfa9ccefc99020f", 2002: "8a4a9dcea54b3ec7de07657b9f2c0893", @@ -76,7 +77,6 @@ def __init__( paths: Union[str, Iterable[str]] = "data", crs: Optional[CRS] = None, res: Optional[float] = None, - years: list[int] = [2021], transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, cache: bool = True, download: bool = False, @@ -100,16 +100,9 @@ def __init__( RuntimeError: if ``download=False`` but dataset is missing or checksum fails AssertionError: if ``year`` is invalid """ - assert set(years) <= self.md5s.keys(), ( - "South America Soybean data only exists for the following years: " - f"{list(self.md5s.keys())}." 
From 4ddaf7087ee63bc7a5f9b41e7e720214f4eddfc5 Mon Sep 17 00:00:00 2001
From: cookie-kyu
Date: Mon, 22 Jan 2024 22:31:47 -0600
Subject: [PATCH 51/72] Fix comments

---
 torchgeo/datasets/south_america_soybean.py | 59 +++++-----------------
 1 file changed, 14 insertions(+), 45 deletions(-)

diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py
index 188f9f82a8b..93fa2c097a1 100644
--- a/torchgeo/datasets/south_america_soybean.py
+++ b/torchgeo/datasets/south_america_soybean.py
@@ -1,3 +1,8 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""South America Soybean Dataset"""
+
 import glob
 import os
 from collections.abc import Iterable
@@ -15,7 +20,7 @@
 class SouthAmericaSoybean(RasterDataset):
     """South America Soybean Dataset

-This dataset produced annual 30-m soybean maps of South America from 2001 to 2021.
+    This dataset produced annual 30-m soybean maps of South America from 2001 to 2021.

     Link: https://www.nature.com/articles/s41893-021-00729-z
@@ -30,47 +35,18 @@ class SouthAmericaSoybean(RasterDataset):

     If you use this dataset in your research, please use the corresponding citation:

-* https://doi.org/10.1038/s41893-021-00729-z
-
-.. versionadded:: 0.6
+    * https://doi.org/10.1038/s41893-021-00729-z

+    ..
versionadded:: 0.6 """ filename_glob = "South_America_Soybean_*.*" filename_regex = r"South_America_Soybean_(?P\d{4})" zipfile_glob = "SouthAmericaSoybean.zip" - date_format = "%Y" is_image = False - url = "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_" - md5 = "7f1d06a57cc6c4ae6be3b3fb9464ddeb" - md5s = { - 2001: "2914b0af7590a0ca4dfa9ccefc99020f", - 2002: "8a4a9dcea54b3ec7de07657b9f2c0893", - 2003: "cad5ed461ff4ab45c90177841aaecad2", - 2004: "f9882ca9c70e054e50172835cb75a8c3", - 2005: "89faae27f9b5afbd06935a465e5fe414", - 2006: "eabaa525414ecbff89301d3d5c706f0b", - 2007: "bb8549b6674163fe20ffd47ec4ce8903", - 2008: "96fc3f737ab3ce9bcd16cbf7761427e2", - 2009: "341387c1bb42a15140c80702e4cca02d", - 2010: "9264532d36ffa93493735a6e44caef0d", - 2011: "b73352ebea3d5658959e9044ec526143", - 2012: "9f3a71097c9836fcff18a13b9ba608b2", - 2013: "0263e19b3cae6fdaba4e3b450cef985e", - 2014: "824ff91c62a4ba9f4ccfd281729830e5", - 2015: "6beb96a61fe0e9ce8c06263e500dde8f", - 2016: "770c558f6ac40550d0e264da5e44b3e", - 2017: "4d0487ac1105d171e5f506f1766ea777", - 2018: "503c2d0a803c2a2629ebbbd9558a3013", - 2019: "441836493bbcd5e123cff579a58f5a4f", - 2020: "0709dec807f576c9707c8c7e183db31", - 2021: "edff3ada13a1a9910d1fe844d28ae4f", - - } - def __init__( self, @@ -84,21 +60,19 @@ def __init__( ) -> None: """Initialize a new Dataset instance. Args: - root: root directory where dataset can be found + paths: one or more root directories to search or files to load crs: :term:`coordinate reference system (CRS)` to warp to (defaults to the CRS of the first file found) res: resolution of the dataset in units of CRS (defaults to the resolution of the first file found) - years: list of years to use transforms: a function/transform that takes an input sample and returns a transformed version cache: if True, cache file handle to speed up repeated sampling download: if True, download dataset and store it in the root directory checksum: if True, check the MD5 after downloading files (may be slow) + Raises: - FileNotFoundError: if no files are found in ``root`` RuntimeError: if ``download=False`` but dataset is missing or checksum fails - AssertionError: if ``year`` is invalid """ self.paths = paths self.download = download @@ -121,10 +95,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: return sample def _verify(self) -> None: - """Verify the integrity of the dataset. 
- Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ + """Verify the integrity of the dataset.""" # Check if the extracted files already exist if self.files: return @@ -134,7 +105,6 @@ def _verify(self) -> None: self._extract() return - # Check if the user requested to download the dataset if not self.download: raise DatasetNotFoundError(self) @@ -142,6 +112,7 @@ def _verify(self) -> None: # Download the dataset self._download() self._extract() + def _download(self) -> None: """Download the dataset.""" filename = "SouthAmericaSoybean.zip" @@ -149,6 +120,7 @@ def _download(self) -> None: download_url( self.url, self.paths, filename, md5=self.md5 if self.checksum else None ) + def _extract(self) -> None: """Extract the dataset.""" assert isinstance(self.paths, str) @@ -157,7 +129,6 @@ def _extract(self) -> None: extract_archive(glob.glob(pathname, recursive=True)[0], self.paths) - def plot( self, sample: dict[str, Any], @@ -200,6 +171,4 @@ def plot( plt.suptitle(suptitle) return fig - - - + \ No newline at end of file From ddecb3ce2cc931ccb8331d7731070d2aa25c2dca Mon Sep 17 00:00:00 2001 From: cookie-kyu Date: Thu, 25 Jan 2024 23:24:36 -0600 Subject: [PATCH 52/72] added dataset to datasets.rst --- docs/api/datasets.rst | 6 ++++++ tests/data/south_america_soybean/data.py | 22 -------------------- tests/datasets/test_south_america_soybean.py | 2 ++ 3 files changed, 8 insertions(+), 22 deletions(-) diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index a6aa00307e5..2e8d047cccc 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -151,6 +151,12 @@ Sentinel .. autoclass:: Sentinel1 .. autoclass:: Sentinel2 +South America Soybean +^^^^ + +.. autoclass:: SouthAmericaSoybean + + .. _Non-geospatial Datasets: Non-geospatial Datasets diff --git a/tests/data/south_america_soybean/data.py b/tests/data/south_america_soybean/data.py index 8c7c3afe2c7..153f5a951a3 100644 --- a/tests/data/south_america_soybean/data.py +++ b/tests/data/south_america_soybean/data.py @@ -2,7 +2,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#os.environ['PROJ_LIB'] = r'E:\Programs\anaconda3\envs\gis\Library\share\proj' import hashlib import os import shutil @@ -13,27 +12,6 @@ from rasterio.transform import Affine SIZE = 32 -wkt = """ -PROJCS["Albers Conical Equal Area", - GEOGCS["WGS 84", - DATUM["WGS_1984", - SPHEROID["WGS 84",6378137,298.257223563, - AUTHORITY["EPSG","7030"]], - AUTHORITY["EPSG","6326"]], - PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]], - UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]], - AUTHORITY["EPSG","4326"]], - PROJECTION["Albers_Conic_Equal_Area"], - PARAMETER["latitude_of_center",23], - PARAMETER["longitude_of_center",-96], - PARAMETER["standard_parallel_1",29.5], - PARAMETER["standard_parallel_2",45.5], - PARAMETER["false_easting",0], - PARAMETER["false_northing",0], - UNIT["meters",1], - AXIS["Easting",EAST], - AXIS["Northing",NORTH]] -""" np.random.seed(0) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index 29085721371..11effb0bef2 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
import os import shutil from pathlib import Path From 7d8eefc1c79f2b95fbba212587827f7de85c8d84 Mon Sep 17 00:00:00 2001 From: cookie-kyu Date: Thu, 25 Jan 2024 23:38:25 -0600 Subject: [PATCH 53/72] edit datasets.rst --- docs/api/datasets.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 2e8d047cccc..42b6d3bc89c 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -151,7 +151,7 @@ Sentinel .. autoclass:: Sentinel1 .. autoclass:: Sentinel2 -South America Soybean +SouthAmericaSoybean ^^^^ .. autoclass:: SouthAmericaSoybean From 243d4881d274bd1d2792b5f1c56786321ee98687 Mon Sep 17 00:00:00 2001 From: cookie-kyu Date: Fri, 26 Jan 2024 11:12:28 -0600 Subject: [PATCH 54/72] pushed again --- docs/api/datasets.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 42b6d3bc89c..7eec8099430 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -150,8 +150,9 @@ Sentinel .. autoclass:: Sentinel .. autoclass:: Sentinel1 .. autoclass:: Sentinel2 + -SouthAmericaSoybean +South America Soybean ^^^^ .. autoclass:: SouthAmericaSoybean From 66d0e93146add2a33baf4c13a349988deb5fe42c Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Fri, 26 Jan 2024 11:25:02 -0600 Subject: [PATCH 55/72] Delete tests/data/south_america_soybean/.DS_Store --- tests/data/south_america_soybean/.DS_Store | Bin 6148 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tests/data/south_america_soybean/.DS_Store diff --git a/tests/data/south_america_soybean/.DS_Store b/tests/data/south_america_soybean/.DS_Store deleted file mode 100644 index 5008ddfcf53c02e82d7eee2e57c38e5672ef89f6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeH~Jr2S!425mzP>H1@V-^m;4Wg<&0T*E43hX&L&p$$qDprKhvt+--jT7}7np#A3 zem<@ulZcFPQ@L2!n>{z**++&mCkOWA81W14cNZlEfg7;MkzE(HCqgga^y>{tEnwC%0;vJ&^%eQ zLs35+`xjp>T0 Date: Fri, 26 Jan 2024 12:51:33 -0600 Subject: [PATCH 56/72] Update docs/api/datasets.rst Co-authored-by: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> --- docs/api/datasets.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 7eec8099430..84ebee5da68 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -153,7 +153,7 @@ Sentinel South America Soybean -^^^^ +^^^^^^^^^^^^^^^^^^ .. autoclass:: SouthAmericaSoybean From 3b9d60692219b4e0debcc34d0d9f49d6b17f1fe6 Mon Sep 17 00:00:00 2001 From: cookie-kyu Date: Fri, 26 Jan 2024 13:03:12 -0600 Subject: [PATCH 57/72] Edited datasets.rst --- docs/api/datasets.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 84ebee5da68..ab1b4ddcb42 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -150,10 +150,10 @@ Sentinel .. autoclass:: Sentinel .. autoclass:: Sentinel1 .. autoclass:: Sentinel2 - + South America Soybean -^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^^ .. 
autoclass:: SouthAmericaSoybean From 83af8827ece4e37d89c72e36f1caa79f311792a6 Mon Sep 17 00:00:00 2001 From: cookie-kyu Date: Fri, 26 Jan 2024 13:14:42 -0600 Subject: [PATCH 58/72] Edited datasets.rst --- docs/api/datasets.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index ab1b4ddcb42..e3971eee4c1 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -152,7 +152,7 @@ Sentinel .. autoclass:: Sentinel2 -South America Soybean +South America Soybean ^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: SouthAmericaSoybean From c90048847b10e6a75222c7147e253f7c90975aea Mon Sep 17 00:00:00 2001 From: cookie-kyu Date: Fri, 26 Jan 2024 17:15:25 -0600 Subject: [PATCH 59/72] Fixed styling --- tests/data/south_america_soybean/data.py | 2 ++ tests/datasets/test_south_america_soybean.py | 31 +++++++++++++------- torchgeo/datasets/south_america_soybean.py | 26 ++++++++-------- 3 files changed, 37 insertions(+), 22 deletions(-) diff --git a/tests/data/south_america_soybean/data.py b/tests/data/south_america_soybean/data.py index 153f5a951a3..88a9273f110 100644 --- a/tests/data/south_america_soybean/data.py +++ b/tests/data/south_america_soybean/data.py @@ -17,6 +17,7 @@ np.random.seed(0) files = ["South_America_Soybean_2002.tif", "South_America_Soybean_2021.tif"] + def create_file(path: str, dtype: str): """Create the testing file.""" profile = { @@ -45,6 +46,7 @@ def create_file(path: str, dtype: str): with rasterio.open(path, "w", **profile) as src: src.write(Z, 1) + if __name__ == "__main__": dir = os.path.join(os.getcwd(), "SouthAmericaSoybean") print(dir) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index 11effb0bef2..15810b3fb9d 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -12,24 +12,33 @@ from rasterio.crs import CRS import torchgeo.datasets.utils -from torchgeo.datasets import SouthAmericaSoybean, BoundingBox, IntersectionDataset, UnionDataset, DatasetNotFoundError +from torchgeo.datasets import ( + BoundingBox, + DatasetNotFoundError, + IntersectionDataset, + SouthAmericaSoybean, + UnionDataset, +) def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: shutil.copy(url, root) + class TestSouthAmericaSoybean: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> SouthAmericaSoybean: - monkeypatch.setattr(torchgeo.datasets.south_america_soybean, "download_url", download_url) + monkeypatch.setattr( + torchgeo.datasets.south_america_soybean, "download_url", download_url + ) transforms = nn.Identity() - url = os.path.join("tests", "data", "south_america_soybean", "SouthAmericaSoybean.zip") + url = os.path.join( + "tests", "data", "south_america_soybean", "SouthAmericaSoybean.zip" + ) monkeypatch.setattr(SouthAmericaSoybean, "url", url) root = str(tmp_path) - return SouthAmericaSoybean(paths=root, - transforms=transforms, - download=True, - checksum=True, + return SouthAmericaSoybean( + paths=root, transforms=transforms, download=True, checksum=True ) def test_getitem(self, dataset: SouthAmericaSoybean) -> None: @@ -50,11 +59,13 @@ def test_already_extracted(self, dataset: SouthAmericaSoybean) -> None: SouthAmericaSoybean(dataset.paths, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: - pathname = os.path.join("tests","data", "south_america_soybean", "SouthAmericaSoybean.zip") + pathname = os.path.join( + "tests", "data", "south_america_soybean", 
"SouthAmericaSoybean.zip" + ) root = str(tmp_path) shutil.copy(pathname, root) SouthAmericaSoybean(root) - + def test_plot(self, dataset: SouthAmericaSoybean) -> None: query = dataset.bounds x = dataset[query] @@ -77,4 +88,4 @@ def test_invalid_query(self, dataset: SouthAmericaSoybean) -> None: with pytest.raises( IndexError, match="query: .* not found in index with bounds:" ): - dataset[query] \ No newline at end of file + dataset[query] diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 93fa2c097a1..d20b39d64ed 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -1,14 +1,13 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -"""South America Soybean Dataset""" +"""South America Soybean Dataset.""" import glob import os from collections.abc import Iterable from typing import Any, Callable, Optional, Union -import torch import matplotlib.pyplot as plt from matplotlib.figure import Figure from rasterio.crs import CRS @@ -18,7 +17,7 @@ class SouthAmericaSoybean(RasterDataset): - """South America Soybean Dataset + """South America Soybean Dataset. This dataset produced annual 30-m soybean maps of South America from 2001 to 2021. @@ -30,15 +29,16 @@ class SouthAmericaSoybean(RasterDataset): Dataset Format: - * 21 .tif files - + * 21 .tif files + If you use this dataset in your research, please use the corresponding citation: - * https://doi.org/10.1038/s41893-021-00729-z + * https://doi.org/10.1038/s41893-021-00729-z .. versionadded:: 0.6 """ + filename_glob = "South_America_Soybean_*.*" filename_regex = r"South_America_Soybean_(?P\d{4})" zipfile_glob = "SouthAmericaSoybean.zip" @@ -59,6 +59,7 @@ def __init__( checksum: bool = False, ) -> None: """Initialize a new Dataset instance. + Args: paths: one or more root directories to search or files to load crs: :term:`coordinate reference system (CRS)` to warp to @@ -80,9 +81,10 @@ def __init__( self._verify() super().__init__(paths, crs, res, transforms=transforms, cache=cache) - + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve mask and metadata indexed by query. + Args: query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index Returns: @@ -91,7 +93,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: IndexError: if query is not found in the index """ sample = super().__getitem__(query) - + return sample def _verify(self) -> None: @@ -116,7 +118,7 @@ def _verify(self) -> None: def _download(self) -> None: """Download the dataset.""" filename = "SouthAmericaSoybean.zip" - + download_url( self.url, self.paths, filename, md5=self.md5 if self.checksum else None ) @@ -124,9 +126,9 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" assert isinstance(self.paths, str) - + pathname = os.path.join(self.paths, "**", self.zipfile_glob) - + extract_archive(glob.glob(pathname, recursive=True)[0], self.paths) def plot( @@ -136,6 +138,7 @@ def plot( suptitle: Optional[str] = None, ) -> Figure: """Plot a sample from the dataset. + Args: sample: a sample returned by :meth:`RasterDataset.__getitem__` show_titles: flag indicating whether to show titles above each panel @@ -171,4 +174,3 @@ def plot( plt.suptitle(suptitle) return fig - \ No newline at end of file From d147b7f5696261c3843f3b1d0afbd0f5a978de63 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Sat, 27 Jan 2024 05:05:33 -0600 Subject: [PATCH 60/72] Fix docstring formatting --- torchgeo/datasets/south_america_soybean.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index d20b39d64ed..237565a502b 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -24,6 +24,7 @@ class SouthAmericaSoybean(RasterDataset): Link: https://www.nature.com/articles/s41893-021-00729-z Dataset contains 2 classes: + 0: nodata 1: soybean From 0326926300198df0fe17d22afd4d92fa1e955cb8 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 27 Jan 2024 05:09:10 -0600 Subject: [PATCH 61/72] Fix whitespace --- docs/api/geo_datasets.csv | 2 +- torchgeo/datasets/south_america_soybean.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/api/geo_datasets.csv b/docs/api/geo_datasets.csv index e0378ace348..cd9caae5e0f 100644 --- a/docs/api/geo_datasets.csv +++ b/docs/api/geo_datasets.csv @@ -24,4 +24,4 @@ Dataset,Type,Source,License,Size (px),Resolution (m) `Open Buildings`_,Geometries,"Maxar, CNES/Airbus","CC-BY-4.0 OR ODbL-1.0",-,- `PRISMA`_,Imagery,PRISMA,-,512x512,5--30 `Sentinel`_,Imagery,Sentinel,"CC-BY-SA-3.0-IGO","10,000x10,000",10 -`South America Soybean`_,Masks,Sentinel-2,-,-,10 \ No newline at end of file +`South America Soybean`_,Masks,Sentinel-2,-,-,10 diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 237565a502b..fb1e989c975 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -24,7 +24,6 @@ class SouthAmericaSoybean(RasterDataset): Link: https://www.nature.com/articles/s41893-021-00729-z Dataset contains 2 classes: - 0: nodata 1: soybean @@ -139,7 +138,6 @@ def plot( suptitle: Optional[str] = None, ) -> Figure: """Plot a sample from the dataset. - Args: sample: a sample returned by :meth:`RasterDataset.__getitem__` show_titles: flag indicating whether to show titles above each panel From f7398676a20f0a6942f9a5874aba79513286dfb3 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 27 Jan 2024 05:10:21 -0600 Subject: [PATCH 62/72] Add blank line --- torchgeo/datasets/south_america_soybean.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index fb1e989c975..329072ae3f5 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -138,6 +138,7 @@ def plot( suptitle: Optional[str] = None, ) -> Figure: """Plot a sample from the dataset. 
+ Args: sample: a sample returned by :meth:`RasterDataset.__getitem__` show_titles: flag indicating whether to show titles above each panel From abe019ac106a4e853f7ec120c35bfee687863183 Mon Sep 17 00:00:00 2001 From: cookie-kyu Date: Sat, 27 Jan 2024 18:39:59 -0600 Subject: [PATCH 63/72] Fixed download urls --- tests/datasets/test_south_america_soybean.py | 27 +++++-- torchgeo/datasets/south_america_soybean.py | 85 ++++++++++++++------ 2 files changed, 82 insertions(+), 30 deletions(-) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index 15810b3fb9d..65be04cbccd 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -32,10 +32,23 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> SouthAmericaSoybe torchgeo.datasets.south_america_soybean, "download_url", download_url ) transforms = nn.Identity() - url = os.path.join( - "tests", "data", "south_america_soybean", "SouthAmericaSoybean.zip" - ) - monkeypatch.setattr(SouthAmericaSoybean, "url", url) + urls = [ + os.path.join( + "tests", + "data", + "south_america_soybean", + "SouthAmericaSoybean", + "South_America_Soybean_2002.tif", + ), + os.path.join( + "tests", + "data", + "south_america_soybean", + "SouthAmericaSoybean", + "South_America_Soybean_2021.tif", + ), + ] + monkeypatch.setattr(SouthAmericaSoybean, "urls", urls) root = str(tmp_path) return SouthAmericaSoybean( paths=root, transforms=transforms, download=True, checksum=True @@ -60,7 +73,11 @@ def test_already_extracted(self, dataset: SouthAmericaSoybean) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join( - "tests", "data", "south_america_soybean", "SouthAmericaSoybean.zip" + "tests", + "data", + "south_america_soybean", + "SouthAmericaSoybean", + "South_America_Soybean_2002.tif", ) root = str(tmp_path) shutil.copy(pathname, root) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index d20b39d64ed..e16e2fb3c32 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -3,8 +3,7 @@ """South America Soybean Dataset.""" -import glob -import os +import re from collections.abc import Iterable from typing import Any, Callable, Optional, Union @@ -13,7 +12,7 @@ from rasterio.crs import CRS from .geo import RasterDataset -from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive +from .utils import BoundingBox, DatasetNotFoundError, download_url class SouthAmericaSoybean(RasterDataset): @@ -41,12 +40,55 @@ class SouthAmericaSoybean(RasterDataset): filename_glob = "South_America_Soybean_*.*" filename_regex = r"South_America_Soybean_(?P\d{4})" - zipfile_glob = "SouthAmericaSoybean.zip" date_format = "%Y" is_image = False - url = "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_" - md5 = "7f1d06a57cc6c4ae6be3b3fb9464ddeb" + urls = [ + "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2001.tif", + "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2002.tif", + "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2003.tif", + "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2004.tif", + "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2005.tif", + "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2006.tif", + 
"https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2007.tif", + "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2008.tif", + "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2009.tif", + "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2010.tif", + "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2011.tif", + "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2012.tif", + "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2013.tif", + "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2014.tif", + "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2015.tif", + "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2016.tif", + "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2017.tif", + "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2018.tif", + "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2019.tif", + "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2020.tif", + "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2021.tif", + ] + md5s = { + 2001: "2914b0af7590a0ca4dfa9ccefc99020f", + 2002: "8a4a9dcea54b3ec7de07657b9f2c0893", + 2003: "cad5ed461ff4ab45c90177841aaecad2", + 2004: "f9882ca9c70e054e50172835cb75a8c3", + 2005: "89faae27f9b5afbd06935a465e5fe414", + 2006: "eabaa525414ecbff89301d3d5c706f0b", + 2007: "bb8549b6674163fe20ffd47ec4ce8903", + 2008: "96fc3f737ab3ce9bcd16cbf7761427e2", + 2009: "341387c1bb42a15140c80702e4cca02d", + 2010: "9264532d36ffa93493735a6e44caef0d", + 2011: "b73352ebea3d5658959e9044ec526143", + 2012: "9f3a71097c9836fcff18a13b9ba608b2", + 2013: "0263e19b3cae6fdaba4e3b450cef985e", + 2014: "824ff91c62a4ba9f4ccfd281729830e5", + 2015: "6beb96a61fe0e9ce8c06263e500dde8f", + 2016: "770c558f6ac40550d0e264da5e44b3e", + 2017: "4d0487ac1105d171e5f506f1766ea777", + 2018: "503c2d0a803c2a2629ebbbd9558a3013", + 2019: "441836493bbcd5e123cff579a58f5a4f", + 2020: "0709dec807f576c9707c8c7e183db31", + 2021: "edff3ada13a1a9910d1fe844d28ae4f", + } def __init__( self, @@ -102,10 +144,6 @@ def _verify(self) -> None: if self.files: return assert isinstance(self.paths, str) - pathname = os.path.join(self.paths, "**", self.zipfile_glob) - if glob.glob(pathname, recursive=True): - self._extract() - return # Check if the user requested to download the dataset if not self.download: @@ -113,23 +151,20 @@ def _verify(self) -> None: # Download the dataset self._download() - self._extract() def _download(self) -> None: """Download the dataset.""" - filename = "SouthAmericaSoybean.zip" - - download_url( - self.url, self.paths, filename, md5=self.md5 if self.checksum else None - ) - - def _extract(self) -> None: - """Extract the dataset.""" - assert isinstance(self.paths, str) - - pathname = os.path.join(self.paths, "**", self.zipfile_glob) - - extract_archive(glob.glob(pathname, recursive=True)[0], self.paths) + file = "SouthAmerica_Soybean_" + num = r"\d+" + for i in range(len(self.urls)): + year = int(re.findall(num, self.urls[i])[0]) + filename = (file + "%d.tif") % year + download_url( + self.urls[i], + self.paths, + filename, + md5=self.md5s[year] if self.checksum else None, + ) def plot( self, @@ -138,7 +173,7 @@ def plot( suptitle: Optional[str] = None, ) -> Figure: """Plot a sample from the dataset. 
- + Args: sample: a sample returned by :meth:`RasterDataset.__getitem__` show_titles: flag indicating whether to show titles above each panel From f130c0c4483dd89b049a1eb8b4895424db0ad964 Mon Sep 17 00:00:00 2001 From: cookie-kyu Date: Sat, 27 Jan 2024 19:04:39 -0600 Subject: [PATCH 64/72] Update geo_datasets.csv --- docs/api/geo_datasets.csv | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api/geo_datasets.csv b/docs/api/geo_datasets.csv index cd9caae5e0f..2707ed40706 100644 --- a/docs/api/geo_datasets.csv +++ b/docs/api/geo_datasets.csv @@ -24,4 +24,4 @@ Dataset,Type,Source,License,Size (px),Resolution (m) `Open Buildings`_,Geometries,"Maxar, CNES/Airbus","CC-BY-4.0 OR ODbL-1.0",-,- `PRISMA`_,Imagery,PRISMA,-,512x512,5--30 `Sentinel`_,Imagery,Sentinel,"CC-BY-SA-3.0-IGO","10,000x10,000",10 -`South America Soybean`_,Masks,Sentinel-2,-,-,10 +`South America Soybean`_,Masks,Sentinel-2,-,-,2 From 50c84604224ff8c308b2af44f47724e5b0c6c235 Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Mon, 29 Jan 2024 10:18:31 -0600 Subject: [PATCH 65/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Adam J. Stewart --- torchgeo/datasets/south_america_soybean.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index e16e2fb3c32..1c867811fa3 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -178,6 +178,7 @@ def plot( sample: a sample returned by :meth:`RasterDataset.__getitem__` show_titles: flag indicating whether to show titles above each panel suptitle: optional string to use as a suptitle + Returns: a matplotlib Figure with the rendered sample """ From d265ea7b0f039cb956f909a2f2d1822a1c50ad8a Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Mon, 29 Jan 2024 10:19:38 -0600 Subject: [PATCH 66/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Adam J. Stewart --- torchgeo/datasets/south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 1c867811fa3..843fdc83ffb 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -31,7 +31,7 @@ class SouthAmericaSoybean(RasterDataset): * 21 .tif files - If you use this dataset in your research, please use the corresponding citation: + If you use this dataset in your research, please cite the following paper: * https://doi.org/10.1038/s41893-021-00729-z From 96af9e54cf6555814c74f3d060ef75ad43fb3efd Mon Sep 17 00:00:00 2001 From: Jingtong <115182031+cookie-kyu@users.noreply.github.com> Date: Mon, 29 Jan 2024 10:19:51 -0600 Subject: [PATCH 67/72] Update torchgeo/datasets/south_america_soybean.py Co-authored-by: Adam J. 
Stewart --- torchgeo/datasets/south_america_soybean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 843fdc83ffb..1485ed7f5e8 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -115,7 +115,7 @@ def __init__( checksum: if True, check the MD5 after downloading files (may be slow) Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails + DatasetNotFoundError: If dataset is not found and *download* is False. """ self.paths = paths self.download = download From 773bd41612795c312099e6e83295aa222ba7cf11 Mon Sep 17 00:00:00 2001 From: cookie-kyu Date: Wed, 31 Jan 2024 16:09:58 -0600 Subject: [PATCH 68/72] Updated geo_datasets.csv and added years parameter to class --- docs/api/geo_datasets.csv | 2 +- tests/data/south_america_soybean/data.py | 1 - tests/datasets/test_south_america_soybean.py | 32 ++++----- torchgeo/datasets/south_america_soybean.py | 73 ++++++++------------ 4 files changed, 43 insertions(+), 65 deletions(-) diff --git a/docs/api/geo_datasets.csv b/docs/api/geo_datasets.csv index 2707ed40706..ed7655e843d 100644 --- a/docs/api/geo_datasets.csv +++ b/docs/api/geo_datasets.csv @@ -24,4 +24,4 @@ Dataset,Type,Source,License,Size (px),Resolution (m) `Open Buildings`_,Geometries,"Maxar, CNES/Airbus","CC-BY-4.0 OR ODbL-1.0",-,- `PRISMA`_,Imagery,PRISMA,-,512x512,5--30 `Sentinel`_,Imagery,Sentinel,"CC-BY-SA-3.0-IGO","10,000x10,000",10 -`South America Soybean`_,Masks,Sentinel-2,-,-,2 +`South America Soybean`_,Masks,"Landsat, MODIS",-,-,30 diff --git a/tests/data/south_america_soybean/data.py b/tests/data/south_america_soybean/data.py index 88a9273f110..fbe7d7b23d1 100644 --- a/tests/data/south_america_soybean/data.py +++ b/tests/data/south_america_soybean/data.py @@ -49,7 +49,6 @@ def create_file(path: str, dtype: str): if __name__ == "__main__": dir = os.path.join(os.getcwd(), "SouthAmericaSoybean") - print(dir) if os.path.exists(dir) and os.path.isdir(dir): shutil.rmtree(dir) diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index 65be04cbccd..11dcc2b5ff9 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -32,26 +32,22 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> SouthAmericaSoybe torchgeo.datasets.south_america_soybean, "download_url", download_url ) transforms = nn.Identity() - urls = [ - os.path.join( - "tests", - "data", - "south_america_soybean", - "SouthAmericaSoybean", - "South_America_Soybean_2002.tif", - ), - os.path.join( - "tests", - "data", - "south_america_soybean", - "SouthAmericaSoybean", - "South_America_Soybean_2021.tif", - ), - ] - monkeypatch.setattr(SouthAmericaSoybean, "urls", urls) + url = os.path.join( + "tests", + "data", + "south_america_soybean", + "SouthAmericaSoybean", + "South_America_Soybean_{}.tif", + ) + + monkeypatch.setattr(SouthAmericaSoybean, "url", url) root = str(tmp_path) return SouthAmericaSoybean( - paths=root, transforms=transforms, download=True, checksum=True + paths=root, + years=[2002, 2021], + transforms=transforms, + download=True, + checksum=True, ) def test_getitem(self, dataset: SouthAmericaSoybean) -> None: diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index e16e2fb3c32..523d356eb99 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ 
b/torchgeo/datasets/south_america_soybean.py @@ -3,7 +3,6 @@ """South America Soybean Dataset.""" -import re from collections.abc import Iterable from typing import Any, Callable, Optional, Union @@ -12,7 +11,7 @@ from rasterio.crs import CRS from .geo import RasterDataset -from .utils import BoundingBox, DatasetNotFoundError, download_url +from .utils import DatasetNotFoundError, download_url class SouthAmericaSoybean(RasterDataset): @@ -43,28 +42,29 @@ class SouthAmericaSoybean(RasterDataset): date_format = "%Y" is_image = False - urls = [ - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2001.tif", - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2002.tif", - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2003.tif", - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2004.tif", - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2005.tif", - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2006.tif", - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2007.tif", - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2008.tif", - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2009.tif", - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2010.tif", - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2011.tif", - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2012.tif", - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2013.tif", - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2014.tif", - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2015.tif", - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2016.tif", - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2017.tif", - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2018.tif", - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2019.tif", - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2020.tif", - "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_2021.tif", + url = "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_{}.tif" + years = [ + 2001, + 2002, + 2003, + 2004, + 2005, + 2006, + 2007, + 2008, + 2009, + 2010, + 2011, + 2012, + 2013, + 2014, + 2015, + 2016, + 2017, + 2018, + 2019, + 2020, + 2021, ] md5s = { 2001: "2914b0af7590a0ca4dfa9ccefc99020f", @@ -93,6 +93,7 @@ class SouthAmericaSoybean(RasterDataset): def __init__( self, paths: Union[str, Iterable[str]] = "data", + years: list[int] = years, crs: Optional[CRS] = None, res: Optional[float] = None, transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, @@ -120,24 +121,11 @@ def __init__( self.paths = paths self.download = download self.checksum = checksum + self.years = years self._verify() super().__init__(paths, crs, res, transforms=transforms, cache=cache) - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: - """Retrieve mask and metadata indexed by query. 
-
-        Args:
-            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
-        Returns:
-            sample of mask and metadata at that index
-        Raises:
-            IndexError: if query is not found in the index
-        """
-        sample = super().__getitem__(query)
-
-        return sample
-
     def _verify(self) -> None:
         """Verify the integrity of the dataset."""
         # Check if the extracted files already exist
@@ -140,9 +116,10 @@ def _verify(self) -> None:
         # Download the dataset
         self._download()

     def _download(self) -> None:
         """Download the dataset."""
-        file = "SouthAmerica_Soybean_"
-        num = r"\d+"
-        for i in range(len(self.urls)):
-            year = int(re.findall(num, self.urls[i])[0])
-            filename = (file + "%d.tif") % year
+        for year in self.years:
             download_url(
-                self.urls[i],
+                self.url.format(year),
                 self.paths,
-                filename,
                 md5=self.md5s[year] if self.checksum else None,
             )

From d84a52a2214691474209e8fdd7ca4cbe9b9aa809 Mon Sep 17 00:00:00 2001
From: Jingtong <115182031+cookie-kyu@users.noreply.github.com>
Date: Wed, 31 Jan 2024 16:41:49 -0600
Subject: [PATCH 69/72] Delete tests/data/.DS_Store

---
 tests/data/.DS_Store | Bin 6148 -> 0 bytes
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 tests/data/.DS_Store

diff --git a/tests/data/.DS_Store b/tests/data/.DS_Store
deleted file mode 100644
index 179e7aa16ad09e16257778afe4e84406ec02afc8..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 6148
[base85-encoded binary data omitted]

From: Jingtong <115182031+cookie-kyu@users.noreply.github.com>
Date: Wed, 31 Jan 2024 16:42:07 -0600
Subject: [PATCH 70/72] Delete tests/.DS_Store

---
 tests/.DS_Store | Bin 6148 -> 0 bytes
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 tests/.DS_Store

diff --git a/tests/.DS_Store b/tests/.DS_Store
deleted file mode 100644
index 5008ddfcf53c02e82d7eee2e57c38e5672ef89f6..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 6148
[base85-encoded binary data omitted]

From 50c84604224ff8c308b2af44f47724e5b0c6c235 Mon Sep 17 00:00:00 2001
From: Jingtong <115182031+cookie-kyu@users.noreply.github.com>
Date: Mon, 5 Feb 2024 21:32:29 -0600
Subject: [PATCH 71/72] Update south_america_soybean.py

---
 torchgeo/datasets/south_america_soybean.py | 67 ++++++++--------------
 1 file changed, 23 insertions(+), 44 deletions(-)

diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py
index 811d8086bcc..c722528953d 100644
--- a/torchgeo/datasets/south_america_soybean.py
+++ b/torchgeo/datasets/south_america_soybean.py
@@ -43,59 +43,37 @@ class SouthAmericaSoybean(RasterDataset):
     date_format = "%Y"
     is_image = False
     url = "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_{}.tif"
-    years = [
-        2001,
-        2002,
-        2003,
-        2004,
-        2005,
-        2006,
-        2007,
-        2008,
-        2009,
-        2010,
-        2011,
-        2012,
-        2013,
-        2014,
-        2015,
-        2016,
-        2017,
-        2018,
-        2019,
-        2020,
-        2021,
-    ]
+
     md5s = {
-        2001: "2914b0af7590a0ca4dfa9ccefc99020f",
-        2002: "8a4a9dcea54b3ec7de07657b9f2c0893",
-        2003: "cad5ed461ff4ab45c90177841aaecad2",
-        2004: "f9882ca9c70e054e50172835cb75a8c3",
-        2005: "89faae27f9b5afbd06935a465e5fe414",
-        2006: "eabaa525414ecbff89301d3d5c706f0b",
-        2007: "bb8549b6674163fe20ffd47ec4ce8903",
-        2008: "96fc3f737ab3ce9bcd16cbf7761427e2",
-        2009: "341387c1bb42a15140c80702e4cca02d",
-        2010: "9264532d36ffa93493735a6e44caef0d",
-        2011: "b73352ebea3d5658959e9044ec526143",
-        2012: "9f3a71097c9836fcff18a13b9ba608b2",
-        2013: "0263e19b3cae6fdaba4e3b450cef985e",
-        2014: "824ff91c62a4ba9f4ccfd281729830e5",
-        2015: "6beb96a61fe0e9ce8c06263e500dde8f",
-        2016: "770c558f6ac40550d0e264da5e44b3e",
-        2017: "4d0487ac1105d171e5f506f1766ea777",
-        2018: "503c2d0a803c2a2629ebbbd9558a3013",
-        2019: "441836493bbcd5e123cff579a58f5a4f",
-        2020: "0709dec807f576c9707c8c7e183db31",
         2021: "edff3ada13a1a9910d1fe844d28ae4f",
+        2020: "0709dec807f576c9707c8c7e183db31",
+        2019: "441836493bbcd5e123cff579a58f5a4f",
+        2018: "503c2d0a803c2a2629ebbbd9558a3013",
+        2017: "4d0487ac1105d171e5f506f1766ea777",
+        2016: "770c558f6ac40550d0e264da5e44b3e",
+        2015: "6beb96a61fe0e9ce8c06263e500dde8f",
+        2014: "824ff91c62a4ba9f4ccfd281729830e5",
+        2013: "0263e19b3cae6fdaba4e3b450cef985e",
+        2012: "9f3a71097c9836fcff18a13b9ba608b2",
+        2011: "b73352ebea3d5658959e9044ec526143",
+        2010: "9264532d36ffa93493735a6e44caef0d",
+        2009: "341387c1bb42a15140c80702e4cca02d",
+        2008: "96fc3f737ab3ce9bcd16cbf7761427e2",
+        2007: "bb8549b6674163fe20ffd47ec4ce8903",
+        2006: "eabaa525414ecbff89301d3d5c706f0b",
+        2005: "89faae27f9b5afbd06935a465e5fe414",
+        2004: "f9882ca9c70e054e50172835cb75a8c3",
+        2003: "cad5ed461ff4ab45c90177841aaecad2",
+        2002: "8a4a9dcea54b3ec7de07657b9f2c0893",
+        2001: "2914b0af7590a0ca4dfa9ccefc99020f",
     }

     def __init__(
         self,
         paths: Union[str, Iterable[str]] = "data",
-        years: list[int] = years,
         crs: Optional[CRS] = None,
         res: Optional[float] = None,
+        years: list[int] = [2021],
         transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
         cache: bool = True,
         download: bool = False,
@@ -109,6 +87,7 @@ def __init__(
             (defaults to the resolution of the first file found)
+            years: list of years for which to use the South America Soybean layer
             transforms: a function/transform that takes an input sample
                 and returns a transformed version
             cache: if True, cache file handle to speed up repeated sampling
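After this patch, per-year downloads are driven entirely by the `years` argument: each requested year is substituted into the `url` template and, when `checksum` is enabled, verified against its `md5s` entry. A compact sketch of that loop outside the class (the two md5 values are copied from the diff above, the `download` helper name and `root` default are illustrative assumptions, and `download_url` is the torchgeo utility used throughout this series):

    from torchgeo.datasets.utils import download_url

    url = "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_{}.tif"
    md5s = {
        2002: "8a4a9dcea54b3ec7de07657b9f2c0893",
        2021: "edff3ada13a1a9910d1fe844d28ae4f",
    }


    def download(years: list[int], root: str = "data", checksum: bool = True) -> None:
        for year in years:
            # One GeoTIFF per year; md5=None skips verification (checksum=False).
            download_url(url.format(year), root, md5=md5s[year] if checksum else None)


    download([2002, 2021])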
"96fc3f737ab3ce9bcd16cbf7761427e2", - 2009: "341387c1bb42a15140c80702e4cca02d", - 2010: "9264532d36ffa93493735a6e44caef0d", - 2011: "b73352ebea3d5658959e9044ec526143", - 2012: "9f3a71097c9836fcff18a13b9ba608b2", - 2013: "0263e19b3cae6fdaba4e3b450cef985e", - 2014: "824ff91c62a4ba9f4ccfd281729830e5", - 2015: "6beb96a61fe0e9ce8c06263e500dde8f", - 2016: "770c558f6ac40550d0e264da5e44b3e", - 2017: "4d0487ac1105d171e5f506f1766ea777", - 2018: "503c2d0a803c2a2629ebbbd9558a3013", - 2019: "441836493bbcd5e123cff579a58f5a4f", - 2020: "0709dec807f576c9707c8c7e183db31", 2021: "edff3ada13a1a9910d1fe844d28ae4f", + 2020: "0709dec807f576c9707c8c7e183db31", + 2019: "441836493bbcd5e123cff579a58f5a4f", + 2018: "503c2d0a803c2a2629ebbbd9558a3013", + 2017: "4d0487ac1105d171e5f506f1766ea777", + 2016: "770c558f6ac40550d0e264da5e44b3e", + 2015: "6beb96a61fe0e9ce8c06263e500dde8f", + 2014: "824ff91c62a4ba9f4ccfd281729830e5", + 2013: "0263e19b3cae6fdaba4e3b450cef985e", + 2012: "9f3a71097c9836fcff18a13b9ba608b2", + 2011: "b73352ebea3d5658959e9044ec526143", + 2010: "9264532d36ffa93493735a6e44caef0d", + 2009: "341387c1bb42a15140c80702e4cca02d", + 2008: "96fc3f737ab3ce9bcd16cbf7761427e2", + 2007: "bb8549b6674163fe20ffd47ec4ce8903", + 2006: "eabaa525414ecbff89301d3d5c706f0b", + 2005: "89faae27f9b5afbd06935a465e5fe414", + 2004: "f9882ca9c70e054e50172835cb75a8c3", + 2003: "cad5ed461ff4ab45c90177841aaecad2", + 2002: "8a4a9dcea54b3ec7de07657b9f2c0893", + 2001: "2914b0af7590a0ca4dfa9ccefc99020f", } def __init__( self, paths: Union[str, Iterable[str]] = "data", - years: list[int] = years, crs: Optional[CRS] = None, res: Optional[float] = None, + years: list[int] = [2021], transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, cache: bool = True, download: bool = False, @@ -109,6 +87,7 @@ def __init__( (defaults to the CRS of the first file found) res: resolution of the dataset in units of CRS (defaults to the resolution of the first file found) + years: list of years for which to use the South America Soybean layer transforms: a function/transform that takes an input sample and returns a transformed version cache: if True, cache file handle to speed up repeated sampling From 477ddd695093873deb24a41a688ccb258497e378 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 6 Feb 2024 03:52:29 -0600 Subject: [PATCH 72/72] Fix docstring formatting --- torchgeo/datasets/south_america_soybean.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index c722528953d..e1d1e952cf0 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -22,8 +22,9 @@ class SouthAmericaSoybean(RasterDataset): Link: https://www.nature.com/articles/s41893-021-00729-z Dataset contains 2 classes: - 0: nodata - 1: soybean + + 0. other + 1. soybean Dataset Format: