From dd30f0a0cb3f1044e5d667e9aad141401380f703 Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Sat, 3 Dec 2022 17:58:56 +0100 Subject: [PATCH] add crop logic to potsdam --- tests/datamodules/test_potsdam.py | 14 ++- torchgeo/datamodules/potsdam.py | 136 ++++++++++++++++++++++++++---- torchgeo/datasets/utils.py | 18 ++++ 3 files changed, 150 insertions(+), 18 deletions(-) diff --git a/tests/datamodules/test_potsdam.py b/tests/datamodules/test_potsdam.py index 5a8ad1f785f..16273d60187 100644 --- a/tests/datamodules/test_potsdam.py +++ b/tests/datamodules/test_potsdam.py @@ -15,7 +15,7 @@ class TestPotsdam2DDataModule: @pytest.fixture(scope="class", params=[0.0, 0.5]) def datamodule(self, request: SubRequest) -> Potsdam2DDataModule: root = os.path.join("tests", "data", "potsdam") - batch_size = 1 + batch_size = 2 num_workers = 0 val_split_size = request.param dm = Potsdam2DDataModule( @@ -23,11 +23,23 @@ def datamodule(self, request: SubRequest) -> Potsdam2DDataModule: batch_size=batch_size, num_workers=num_workers, val_split_pct=val_split_size, + num_tiles_per_batch=1, ) dm.prepare_data() dm.setup() return dm + def test_batch_size_warning(self, datamodule: Potsdam2DDataModule) -> None: + match = "The effective batch size will differ" + with pytest.warns(UserWarning, match=match): + Potsdam2DDataModule( + root=datamodule.test_dataset.root, + batch_size=3, + num_tiles_per_batch=2, + num_workers=datamodule.num_workers, + val_split_pct=datamodule.val_split_pct, + ) + def test_train_dataloader(self, datamodule: Potsdam2DDataModule) -> None: next(iter(datamodule.train_dataloader())) diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index 7721fd15a70..832f4cc4d6a 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -3,13 +3,19 @@ """Potsdam datamodule.""" -from typing import Any, Dict, Optional +import warnings +from typing import Any, Dict, Optional, Tuple, Union +import kornia.augmentation as K import matplotlib.pyplot as plt import pytorch_lightning as pl +import torch from torch.utils.data import DataLoader, Dataset from torchvision.transforms import Compose +from torchgeo.datasets.utils import collate_patches_per_tile +from torchgeo.samplers.utils import _to_tuple + from ..datasets import Potsdam2D from .utils import dataset_split @@ -27,6 +33,8 @@ def __init__( batch_size: int = 64, num_workers: int = 0, val_split_pct: float = 0.2, + patch_size: Union[Tuple[int, int], int] = (64, 64), + num_tiles_per_batch: int = 16, **kwargs: Any, ) -> None: """Initialize a LightningDataModule for Potsdam2D based DataLoaders. @@ -35,15 +43,47 @@ def __init__( batch_size: The batch size to use in all created DataLoaders num_workers: The number of workers to use in all created DataLoaders val_split_pct: What percentage of the dataset to use as a validation set + patch_size: Size of random patch from image and mask (height, width), should + be a multiple of 32 for most segmentation architectures + num_tiles_per_batch: number of random tiles to consider sampling patches + from per sample, should evenly divide batch_size and be less than + or equal to batch_size **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.Potsdam2D` + + .. versionchanged:: 0.4 + 'patch_size' and 'num_tiles_per_batch' introduced in order to randomly + crop the variable size images during training """ super().__init__() self.batch_size = batch_size self.num_workers = num_workers + self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct self.kwargs = kwargs + assert ( + self.batch_size >= num_tiles_per_batch + ), "num_tiles_per_batch should be less than or equal to batch_size." + + self.num_patches_per_tile = self.batch_size // num_tiles_per_batch + self.num_tiles_per_batch = num_tiles_per_batch + + if (self.num_patches_per_tile % 2) != 0 and ( + self.num_patches_per_tile != num_tiles_per_batch + ): + warnings.warn( + "The effective batch size" + f" will differ from the specified {batch_size}" + f" and be {self.num_patches_per_tile * num_tiles_per_batch} instead." + " To match the batch_size exactly, ensure that" + " num_tiles_per_batch evenly divides batch_size" + ) + + self.rcrop = K.AugmentationSequential( + K.RandomCrop(self.patch_size), data_keys=["input", "mask"] + ) + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: """Transform a single sample from the Dataset. @@ -64,27 +104,79 @@ def setup(self, stage: Optional[str] = None) -> None: Args: stage: stage to set up + + .. versionchanged:: 0.4 + Add functionality to randomly crop patches from a tile during + training and pad validation and test samples to next multiple of 32 """ - transforms = Compose([self.preprocess]) - dataset = Potsdam2D(split="train", transforms=transforms, **self.kwargs) + def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]: + """Construct 'num_patches_per_tile' random patches of input tile. + + Args: + sample: contains image and mask tile from dataset + + Returns: + stacked randomly cropped patches from input tile + """ + images, masks = [], [] + for i in range(self.num_patches_per_tile): + image, mask = self.rcrop(sample["image"], sample["mask"].float()) + images.append(image.squeeze(0)) + masks.append(mask.squeeze().long()) + + sample["image"] = torch.stack(images) + sample["mask"] = torch.stack(masks) + return sample + + def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]: + """Pad image and mask to next multiple of 32. + + Args: + sample: contains image and mask sample from dataset + + Returns: + padded image and mask + """ + h, w = sample["image"].shape[1], sample["image"].shape[2] + new_h = int(32 * ((h // 32) + 1)) + new_w = int(32 * ((w // 32) + 1)) + + padto = K.PadTo((new_h, new_w)) + + sample["image"] = padto(sample["image"])[0] + sample["mask"] = padto(sample["mask"].float()).long()[0, 0] + return sample + + train_transforms = Compose([self.preprocess, n_random_crop]) + # for testing and validation we pad all inputs to next larger multiple of 32 + # to avoid issues with upsampling paths in encoder-decoder architectures + test_transforms = Compose([self.preprocess, pad_to]) + + train_dataset = Potsdam2D( + split="train", transforms=train_transforms, **self.kwargs + ) self.train_dataset: Dataset[Any] self.val_dataset: Dataset[Any] if self.val_split_pct > 0.0: + val_dataset = Potsdam2D( + split="train", transforms=test_transforms, **self.kwargs + ) self.train_dataset, self.val_dataset, _ = dataset_split( - dataset, val_pct=self.val_split_pct, test_pct=0.0 + train_dataset, val_pct=self.val_split_pct, test_pct=0.0 ) + self.val_dataset.dataset = val_dataset else: - self.train_dataset = dataset - self.val_dataset = dataset + self.train_dataset = train_dataset + self.val_dataset = train_dataset self.test_dataset = Potsdam2D( - split="test", transforms=transforms, **self.kwargs + split="test", transforms=test_transforms, **self.kwargs ) - def train_dataloader(self) -> DataLoader[Any]: + def train_dataloader(self) -> DataLoader[Dict[str, Any]]: """Return a DataLoader for training. Returns: @@ -92,25 +184,35 @@ def train_dataloader(self) -> DataLoader[Any]: """ return DataLoader( self.train_dataset, - batch_size=self.batch_size, + batch_size=self.num_tiles_per_batch, num_workers=self.num_workers, + collate_fn=collate_patches_per_tile, shuffle=True, ) - def val_dataloader(self) -> DataLoader[Any]: + def val_dataloader(self) -> DataLoader[Dict[str, Any]]: """Return a DataLoader for validation. Returns: validation data loader """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) + if self.val_split_pct > 0.0: + return DataLoader( + self.val_dataset, + batch_size=1, + num_workers=self.num_workers, + shuffle=False, + ) + else: + return DataLoader( + self.val_dataset, + batch_size=1, + num_workers=self.num_workers, + shuffle=False, + collate_fn=collate_patches_per_tile, + ) - def test_dataloader(self) -> DataLoader[Any]: + def test_dataloader(self) -> DataLoader[Dict[str, Any]]: """Return a DataLoader for testing. Returns: diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 3270356f022..ebefe3e7d59 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -30,7 +30,9 @@ import numpy as np import rasterio import torch +from einops import rearrange from torch import Tensor +from torch.utils.data._utils.collate import default_collate from torchvision.datasets.utils import check_integrity, download_url from torchvision.utils import draw_segmentation_masks @@ -216,6 +218,22 @@ def download_radiant_mlhub_collection( collection.download(output_dir=download_root, api_key=api_key) +def collate_patches_per_tile(batch: List[Dict[str, Any]]) -> Dict[str, Any]: + """Define collate function to combine patches per tile and batch size. + + Args: + batch: sample batch from dataloader containing image and mask + + Returns: + sample batch where the batch dimension is + 'train_batch_size' * 'num_patches_per_tile' + """ + r_batch: Dict[str, Any] = default_collate(batch) # type: ignore[no-untyped-call] + r_batch["image"] = rearrange(r_batch["image"], "b t c h w -> (b t) c h w") + r_batch["mask"] = rearrange(r_batch["mask"], "b t h w -> (b t) h w") + return r_batch + + @dataclass(frozen=True) class BoundingBox: """Data class for indexing spatiotemporal data."""