diff --git a/tests/datasets/test_cms_mangrove_canopy.py b/tests/datasets/test_cms_mangrove_canopy.py index 54a41d75fb5..3cbe4173e89 100644 --- a/tests/datasets/test_cms_mangrove_canopy.py +++ b/tests/datasets/test_cms_mangrove_canopy.py @@ -90,4 +90,10 @@ def test_or(self, dataset: CMSGlobalMangroveCanopy) -> None: def test_plot(self, dataset: CMSGlobalMangroveCanopy) -> None: query = dataset.bounds x = dataset[query] - dataset.plot(x["mask"]) + dataset.plot(x, suptitle="Test") + + def test_plot_prediction(self, dataset: CMSGlobalMangroveCanopy) -> None: + query = dataset.bounds + x = dataset[query] + x["prediction"] = x["mask"].clone() + dataset.plot(x, suptitle="Prediction") diff --git a/torchgeo/datasets/cms_mangrove_canopy.py b/torchgeo/datasets/cms_mangrove_canopy.py index a9484bd6412..a2e1a2cfe8b 100644 --- a/torchgeo/datasets/cms_mangrove_canopy.py +++ b/torchgeo/datasets/cms_mangrove_canopy.py @@ -7,6 +7,7 @@ import os from typing import Any, Callable, Dict, Optional +import matplotlib.pyplot as plt from rasterio.crs import CRS from .geo import RasterDataset @@ -249,3 +250,48 @@ def _extract(self) -> None: """Extract the dataset.""" pathname = os.path.join(self.root, self.zipfile) extract_archive(pathname) + + def plot( # type: ignore[override] + self, + sample: Dict[str, Any], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`RasterDataset.__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + """ + mask = sample["mask"].squeeze() + ncols = 1 + + showing_predictions = "prediction" in sample + if showing_predictions: + pred = sample["prediction"].squeeze() + ncols = 2 + + fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 4, 4)) + + if showing_predictions: + axs[0].imshow(mask) + axs[0].axis("off") + axs[1].imshow(pred) + axs[1].axis("off") + if show_titles: + axs[0].set_title("Mask") + axs[1].set_title("Prediction") + else: + axs.imshow(mask) + axs.axis("off") + if show_titles: + axs.set_title("Mask") + + if suptitle is not None: + plt.suptitle(suptitle) + + return