diff --git a/tests/data/levircd/LEVIR-CD+.zip b/tests/data/levircd/LEVIR-CD+.zip index b51dc099207..9a5fa4e1a7c 100644 Binary files a/tests/data/levircd/LEVIR-CD+.zip and b/tests/data/levircd/LEVIR-CD+.zip differ diff --git a/tests/datasets/test_levircd.py b/tests/datasets/test_levircd.py index 2aca6c8b0c5..f61bc241be8 100644 --- a/tests/datasets/test_levircd.py +++ b/tests/datasets/test_levircd.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -31,7 +32,7 @@ def dataset( monkeypatch.setattr( # type: ignore[attr-defined] torchgeo.datasets.utils, "download_url", download_url ) - md5 = "b61c300e9fd7146eb2c8e2512c0e9d39" + md5 = "1adf156f628aa32fb2e8fe6cada16c04" monkeypatch.setattr(LEVIRCDPlus, "md5", md5) # type: ignore[attr-defined] url = os.path.join("tests", "data", "levircd", "LEVIR-CD+.zip") monkeypatch.setattr(LEVIRCDPlus, "url", url) # type: ignore[attr-defined] @@ -60,3 +61,12 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): LEVIRCDPlus(str(tmp_path)) + + def test_plot(self, dataset: LEVIRCDPlus) -> None: + dataset.plot(dataset[0], suptitle="Test") + plt.close() + + sample = dataset[0] + sample["prediction"] = sample["mask"].clone() + dataset.plot(sample, suptitle="Prediction") + plt.close() diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index 24d76ca6594..9098a23b585 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -7,6 +7,7 @@ import os from typing import Callable, Dict, List, Optional +import matplotlib.pyplot as plt import numpy as np import torch from PIL import Image @@ -47,6 +48,7 @@ class LEVIRCDPlus(VisionDataset): url = "https://drive.google.com/file/d/1JamSsxiytXdzAIk6VDVWfc-OsX-81U81" md5 = "1adf156f628aa32fb2e8fe6cada16c04" filename = "LEVIR-CD+.zip" + directory = "LEVIR-CD+" splits = ["train", "test"] def __init__( @@ -88,7 +90,7 @@ def __init__( + "You can use download=True to download it" ) - self.files = self._load_files(self.root, self.split) + self.files = self._load_files(self.root, self.directory, self.split) def __getitem__(self, index: int) -> Dict[str, Tensor]: """Return an index within the dataset. @@ -120,23 +122,26 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str, split: str) -> List[Dict[str, str]]: + def _load_files( + self, root: str, directory: str, split: str + ) -> List[Dict[str, str]]: """Return the paths of the files in the dataset. Args: root: root dir of dataset + directory: sub directory LEVIR-CD+ split: subset of dataset, one of [train, test] Returns: list of dicts containing paths for each pair of image1, image2, mask """ files = [] - images = glob.glob(os.path.join(root, split, "A", "*.png")) + images = glob.glob(os.path.join(root, directory, split, "A", "*.png")) images = sorted([os.path.basename(image) for image in images]) for image in images: - image1 = os.path.join(root, split, "A", image) - image2 = os.path.join(root, split, "B", image) - mask = os.path.join(root, split, "label", image) + image1 = os.path.join(root, directory, split, "A", image) + image2 = os.path.join(root, directory, split, "B", image) + mask = os.path.join(root, directory, split, "label", image) files.append(dict(image1=image1, image2=image2, mask=mask)) return files @@ -181,7 +186,7 @@ def _check_integrity(self) -> bool: True if the dataset directories and split files are found, else False """ for filename in self.splits: - filepath = os.path.join(self.root, filename) + filepath = os.path.join(self.root, self.directory, filename) if not os.path.exists(filepath): return False return True @@ -202,3 +207,53 @@ def _download(self) -> None: filename=self.filename, md5=self.md5 if self.checksum else None, ) + + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional suptitle to use for figure + + Returns: + a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 + """ + image1, image2, mask = (sample["image"][0], sample["image"][1], sample["mask"]) + ncols = 3 + + if "prediction" in sample: + prediction = sample["prediction"] + ncols += 1 + + fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5)) + + axs[0].imshow(image1.permute(1, 2, 0)) + axs[0].axis("off") + axs[1].imshow(image2.permute(1, 2, 0)) + axs[1].axis("off") + axs[2].imshow(mask) + axs[2].axis("off") + + if "prediction" in sample: + axs[3].imshow(prediction) + axs[3].axis("off") + if show_titles: + axs[3].set_title("Prediction") + + if show_titles: + axs[0].set_title("Image 1") + axs[1].set_title("Image 2") + axs[2].set_title("Mask") + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig