diff --git a/docs/api/geo_datasets.csv b/docs/api/geo_datasets.csv index f420f70cfea..9bee384cd95 100644 --- a/docs/api/geo_datasets.csv +++ b/docs/api/geo_datasets.csv @@ -12,7 +12,7 @@ Dataset,Type,Source,Size (px),Resolution (m) `GBIF`_,Points,Citizen Scientists,-,- `GlobBiomass`_,Masks,Landsat,"45,000x45,000",100 `iNaturalist`_,Points,Citizen Scientists,-,- -`Landsat`_,Imagery,Landsat,-,30 +`Landsat`_,Imagery,Landsat,"8,900x8,900",30 `NAIP`_,Imagery,Aerial,"6,100x7,600",1 `Open Buildings`_,Geometries,"Maxar, CNES/Airbus",-,- `Sentinel`_,Imagery,Sentinel,"10,000x10,000",10 diff --git a/tests/datasets/test_landsat.py b/tests/datasets/test_landsat.py index 59b6a0f61a4..bc1ff2c8aea 100644 --- a/tests/datasets/test_landsat.py +++ b/tests/datasets/test_landsat.py @@ -4,6 +4,7 @@ import os from pathlib import Path +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -38,6 +39,20 @@ def test_or(self, dataset: Landsat8) -> None: ds = dataset | dataset assert isinstance(ds, UnionDataset) + def test_plot(self, dataset: Landsat8) -> None: + x = dataset[dataset.bounds] + dataset.plot(x, suptitle="Test") + plt.close() + + def test_plot_wrong_bands(self, dataset: Landsat8) -> None: + bands = ("SR_B1",) + ds = Landsat8(root=dataset.root, bands=bands) + x = dataset[dataset.bounds] + with pytest.raises( + ValueError, match="Dataset doesn't contain some of the RGB bands" + ): + ds.plot(x) + def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(FileNotFoundError, match="No Landsat8 data was found in "): Landsat8(str(tmp_path)) diff --git a/torchgeo/datasets/cdl.py b/torchgeo/datasets/cdl.py index 9d9c5242564..de4023fc50b 100644 --- a/torchgeo/datasets/cdl.py +++ b/torchgeo/datasets/cdl.py @@ -413,6 +413,10 @@ def plot( Returns: a matplotlib Figure with the rendered sample + + .. versionchanged:: 0.3 + Method now takes a sample dict, not a Tensor. Additionally, possible to + show subplot titles and/or use a custom suptitle. """ mask = sample["mask"].squeeze().numpy() ncols = 1 diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index 7557136d277..59528f9c0fe 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -19,7 +19,6 @@ import torch from matplotlib.colors import ListedColormap from rasterio.crs import CRS -from torch import Tensor from .geo import GeoDataset, RasterDataset from .utils import BoundingBox, download_url, extract_archive @@ -178,7 +177,7 @@ def _extract(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: Dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: @@ -192,7 +191,9 @@ def plot( Returns: a matplotlib Figure with the rendered sample - .. versionadded:: 0.3 + .. versionchanged:: 0.3 + Method now takes a sample dict, not a Tensor. Additionally, possible to + show subplot titles and/or use a custom suptitle. """ mask = sample["mask"].squeeze(0) ncols = 1 diff --git a/torchgeo/datasets/landsat.py b/torchgeo/datasets/landsat.py index 0f62f0201ad..e3cc15d6faa 100644 --- a/torchgeo/datasets/landsat.py +++ b/torchgeo/datasets/landsat.py @@ -6,6 +6,7 @@ import abc from typing import Any, Callable, Dict, Optional, Sequence +import matplotlib.pyplot as plt from rasterio.crs import CRS from .geo import RasterDataset @@ -78,6 +79,54 @@ def __init__( super().__init__(root, crs, res, transforms, cache) + def plot( + self, + sample: Dict[str, Any], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`RasterDataset.__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + Raises: + ValueError: if the RGB bands are not included in ``self.bands`` + + .. versionchanged:: 0.3 + Method now takes a sample dict, not a Tensor. Additionally, possible to + show subplot titles and/or use a custom suptitle. + """ + rgb_indices = [] + for band in self.rgb_bands: + if band in self.bands: + rgb_indices.append(self.bands.index(band)) + else: + raise ValueError("Dataset doesn't contain some of the RGB bands") + + image = sample["image"][rgb_indices].permute(1, 2, 0).float() + + # Stretch to the full range + image = (image - image.min()) / (image.max() - image.min()) + + fig, ax = plt.subplots(1, 1, figsize=(4, 4)) + + ax.imshow(image) + ax.axis("off") + + if show_titles: + ax.set_title("Image") + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig + class Landsat1(Landsat): """Landsat 1 Multispectral Scanner (MSS).""" diff --git a/torchgeo/datasets/naip.py b/torchgeo/datasets/naip.py index b3d054a492f..9274b830129 100644 --- a/torchgeo/datasets/naip.py +++ b/torchgeo/datasets/naip.py @@ -64,8 +64,8 @@ def plot( a matplotlib Figure with the rendered sample .. versionchanged:: 0.3 - Method now takes a sample dict, not a Tensor. Additionally, possible to - show subplot titles and/or use a custom suptitle. + Method now takes a sample dict, not a Tensor. Additionally, possible to + show subplot titles and/or use a custom suptitle. """ image = sample["image"][0:3, :, :].permute(1, 2, 0) diff --git a/torchgeo/datasets/sentinel.py b/torchgeo/datasets/sentinel.py index 1f347bd1495..cc8817a8c85 100644 --- a/torchgeo/datasets/sentinel.py +++ b/torchgeo/datasets/sentinel.py @@ -8,7 +8,6 @@ import matplotlib.pyplot as plt import torch from rasterio.crs import CRS -from torch import Tensor from .geo import RasterDataset @@ -104,7 +103,7 @@ def __init__( def plot( self, - sample: Dict[str, Tensor], + sample: Dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: @@ -121,7 +120,9 @@ def plot( Raises: ValueError: if the RGB bands are not included in ``self.bands`` - .. versionadded:: 0.3 + .. versionchanged:: 0.3 + Method now takes a sample dict, not a Tensor. Additionally, possible to + show subplot titles and/or use a custom suptitle. """ rgb_indices = [] for band in self.RGB_BANDS: