diff --git a/conf/deepglobelandcover.yaml b/conf/deepglobelandcover.yaml
new file mode 100644
index 00000000000..7732406da45
--- /dev/null
+++ b/conf/deepglobelandcover.yaml
@@ -0,0 +1,21 @@
+experiment:
+  task: "deepglobelandcover"
+  module:
+    loss: "ce"
+    model: "unet"
+    backbone: "resnet18"
+    weights: null
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 6
+    verbose: false
+    in_channels: 3
+    num_classes: 7
+    num_filters: 1
+    ignore_index: null
+  datamodule:
+    root: "data/deepglobelandcover"
+    num_tiles_per_batch: 16
+    num_patches_per_tile: 16
+    patch_size: 64
+    val_split_pct: 0.5
+    num_workers: 0
diff --git a/conf/vaihingen2d.yaml b/conf/vaihingen2d.yaml
new file mode 100644
index 00000000000..0f3015faf66
--- /dev/null
+++ b/conf/vaihingen2d.yaml
@@ -0,0 +1,21 @@
+experiment:
+  task: "vaihingen2d"
+  module:
+    loss: "ce"
+    model: "unet"
+    backbone: "resnet18"
+    weights: null
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 6
+    verbose: false
+    in_channels: 3
+    num_classes: 7
+    num_filters: 1
+    ignore_index: null
+  datamodule:
+    root: "data/vaihingen"
+    num_tiles_per_batch: 16
+    num_patches_per_tile: 16
+    patch_size: 64
+    val_split_pct: 0.5
+    num_workers: 0
diff --git a/tests/conf/vaihingen2d.yaml b/tests/conf/vaihingen2d.yaml
new file mode 100644
index 00000000000..0184b929a0a
--- /dev/null
+++ b/tests/conf/vaihingen2d.yaml
@@ -0,0 +1,21 @@
+experiment:
+  task: "vaihingen2d"
+  module:
+    loss: "ce"
+    model: "unet"
+    backbone: "resnet18"
+    weights: null
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 6
+    verbose: false
+    in_channels: 3
+    num_classes: 7
+    num_filters: 1
+    ignore_index: null
+  datamodule:
+    root: "tests/data/vaihingen"
+    num_tiles_per_batch: 1
+    num_patches_per_tile: 1
+    patch_size: 2
+    val_split_pct: 0.5
+    num_workers: 0
diff --git a/tests/datamodules/test_vaihingen.py b/tests/datamodules/test_vaihingen.py
deleted file mode 100644
index 13fd2d52e4c..00000000000
--- a/tests/datamodules/test_vaihingen.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-
-import os
-
-import matplotlib.pyplot as plt
-import pytest
-from _pytest.fixtures import SubRequest
-
-from torchgeo.datamodules import Vaihingen2DDataModule
-from torchgeo.datasets import unbind_samples
-
-
-class TestVaihingen2DDataModule:
-    @pytest.fixture(scope="class", params=[0.0, 0.5])
-    def datamodule(self, request: SubRequest) -> Vaihingen2DDataModule:
-        root = os.path.join("tests", "data", "vaihingen")
-        batch_size = 1
-        num_workers = 0
-        val_split_size = request.param
-        dm = Vaihingen2DDataModule(
-            root=root,
-            batch_size=batch_size,
-            num_workers=num_workers,
-            val_split_pct=val_split_size,
-        )
-        dm.prepare_data()
-        dm.setup()
-        return dm
-
-    def test_train_dataloader(self, datamodule: Vaihingen2DDataModule) -> None:
-        next(iter(datamodule.train_dataloader()))
-
-    def test_val_dataloader(self, datamodule: Vaihingen2DDataModule) -> None:
-        next(iter(datamodule.val_dataloader()))
-
-    def test_test_dataloader(self, datamodule: Vaihingen2DDataModule) -> None:
-        next(iter(datamodule.test_dataloader()))
-
-    def test_plot(self, datamodule: Vaihingen2DDataModule) -> None:
-        batch = next(iter(datamodule.train_dataloader()))
-        sample = unbind_samples(batch)[0]
-        datamodule.plot(sample)
-        plt.close()
diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py
index 6c9b1f2ff69..1ffb3931516 100644
--- a/tests/trainers/test_segmentation.py
+++ b/tests/trainers/test_segmentation.py
@@ -21,6 +21,7 @@
     NAIPChesapeakeDataModule,
     SEN12MSDataModule,
     SpaceNet1DataModule,
+    Vaihingen2DDataModule,
 )
 from torchgeo.datasets import LandCoverAI
 from torchgeo.trainers import SemanticSegmentationTask
@@ -50,6 +51,7 @@ class TestSemanticSegmentationTask:
             ("sen12ms_s2_all", SEN12MSDataModule),
             ("sen12ms_s2_reduced", SEN12MSDataModule),
             ("spacenet1", SpaceNet1DataModule),
+            ("vaihingen2d", Vaihingen2DDataModule),
         ],
     )
     def test_trainer(
diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py
index ac71f69bd77..bd25f0a7c1c 100644
--- a/torchgeo/datamodules/deepglobelandcover.py
+++ b/torchgeo/datamodules/deepglobelandcover.py
@@ -7,7 +7,9 @@
 
 import matplotlib.pyplot as plt
 import pytorch_lightning as pl
+from einops import rearrange
 from kornia.augmentation import Normalize
+from torch import Tensor
 from torch.utils.data import DataLoader
 
 from ..datasets import DeepGlobeLandCover
@@ -90,7 +92,7 @@ def setup(self, stage: Optional[str] = None) -> None:
         )
         self.test_dataset = DeepGlobeLandCover(split="test", **self.kwargs)
 
-    def train_dataloader(self) -> DataLoader[Dict[str, Any]]:
+    def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Return a DataLoader for training.
 
         Returns:
@@ -103,7 +105,7 @@ def train_dataloader(self) -> DataLoader[Dict[str, Any]]:
             shuffle=True,
         )
 
-    def val_dataloader(self) -> DataLoader[Dict[str, Any]]:
+    def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Return a DataLoader for validation.
 
         Returns:
@@ -113,7 +115,7 @@ def val_dataloader(self) -> DataLoader[Dict[str, Any]]:
             self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
         )
 
-    def test_dataloader(self) -> DataLoader[Dict[str, Any]]:
+    def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Return a DataLoader for testing.
 
         Returns:
@@ -124,8 +126,8 @@ def test_dataloader(self) -> DataLoader[Dict[str, Any]]:
         )
 
     def on_after_batch_transfer(
-        self, batch: Dict[str, Any], dataloader_idx: int
-    ) -> Dict[str, Any]:
+        self, batch: Dict[str, Tensor], dataloader_idx: int
+    ) -> Dict[str, Tensor]:
         """Apply augmentations to batch after transferring to GPU.
 
         Args:
@@ -136,7 +138,7 @@ def on_after_batch_transfer(
             A batch of data
         """
         # Kornia requires masks to have a channel dimension
-        batch["mask"] = batch["mask"].unsqueeze(1)
+        batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w")
 
         if self.trainer:
             if self.trainer.training:
@@ -145,7 +147,7 @@ def on_after_batch_transfer(
                 batch = self.test_transform(batch)
 
         # Torchmetrics does not support masks with a channel dimension
-        batch["mask"] = batch["mask"].squeeze(1)
+        batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w")
 
         return batch
 
diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py
index 6017b1cce77..aa1499a6742 100644
--- a/torchgeo/datamodules/vaihingen.py
+++ b/torchgeo/datamodules/vaihingen.py
@@ -3,14 +3,19 @@
 
 """Vaihingen datamodule."""
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple, Union
 
 import matplotlib.pyplot as plt
 import pytorch_lightning as pl
-from torch.utils.data import DataLoader, Dataset
-from torchvision.transforms import Compose
+from einops import rearrange
+from kornia.augmentation import Normalize
+from torch import Tensor
+from torch.utils.data import DataLoader
 
 from ..datasets import Vaihingen2D
+from ..samplers.utils import _to_tuple
+from ..transforms import AugmentationSequential
+from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop
 from .utils import dataset_split
 
 
@@ -24,67 +29,72 @@ class Vaihingen2DDataModule(pl.LightningDataModule):
 
     def __init__(
         self,
-        batch_size: int = 64,
-        num_workers: int = 0,
+        num_tiles_per_batch: int = 16,
+        num_patches_per_tile: int = 16,
+        patch_size: Union[Tuple[int, int], int] = 64,
         val_split_pct: float = 0.2,
+        num_workers: int = 0,
         **kwargs: Any,
     ) -> None:
-        """Initialize a LightningDataModule for Vaihingen2D based DataLoaders.
+        """Initialize a new LightningDataModule instance.
+
+        The Vaihingen2D dataset contains images that are too large to pass
+        directly through a model. Instead, we randomly sample patches from image tiles
+        during training and chop up image tiles into patch grids during evaluation.
+        During training, the effective batch size is equal to
+        ``num_tiles_per_batch`` x ``num_patches_per_tile``.
 
         Args:
-            batch_size: The batch size to use in all created DataLoaders
-            num_workers: The number of workers to use in all created DataLoaders
-            val_split_pct: What percentage of the dataset to use as a validation set
+            num_tiles_per_batch: The number of image tiles to sample from during
+                training
+            num_patches_per_tile: The number of patches to randomly sample from each
+                image tile during training
+            patch_size: The size of each patch, either ``size`` or ``(height, width)``.
+                Should be a multiple of 32 for most segmentation architectures
+            val_split_pct: The percentage of the dataset to use as a validation set
+            num_workers: The number of workers to use for parallel data loading
             **kwargs: Additional keyword arguments passed to
                 :class:`~torchgeo.datasets.Vaihingen2D`
+
+        .. versionchanged:: 0.4
+           *batch_size* was replaced by *num_tiles_per_batch*, *num_patches_per_tile*,
+           and *patch_size*.
""" super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers + + self.num_tiles_per_batch = num_tiles_per_batch + self.num_patches_per_tile = num_patches_per_tile + self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct + self.num_workers = num_workers self.kwargs = kwargs - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample + self.train_transform = AugmentationSequential( + Normalize(mean=0.0, std=255.0), + _RandomNCrop(self.patch_size, self.num_patches_per_tile), + data_keys=["image", "mask"], + ) + self.test_transform = AugmentationSequential( + Normalize(mean=0.0, std=255.0), + _ExtractTensorPatches(self.patch_size), + data_keys=["image", "mask"], + ) def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. + """Initialize the main Dataset objects. This method is called once per GPU per run. Args: stage: stage to set up """ - transforms = Compose([self.preprocess]) - - dataset = Vaihingen2D(split="train", transforms=transforms, **self.kwargs) - - self.train_dataset: Dataset[Any] - self.val_dataset: Dataset[Any] - - if self.val_split_pct > 0.0: - self.train_dataset, self.val_dataset, _ = dataset_split( - dataset, val_pct=self.val_split_pct, test_pct=0.0 - ) - else: - self.train_dataset = dataset - self.val_dataset = dataset - - self.test_dataset = Vaihingen2D( - split="test", transforms=transforms, **self.kwargs + train_dataset = Vaihingen2D(split="train", **self.kwargs) + self.train_dataset, self.val_dataset = dataset_split( + train_dataset, self.val_split_pct ) + self.test_dataset = Vaihingen2D(split="test", **self.kwargs) - def train_dataloader(self) -> DataLoader[Any]: + def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: """Return a DataLoader for training. Returns: @@ -92,37 +102,57 @@ def train_dataloader(self) -> DataLoader[Any]: """ return DataLoader( self.train_dataset, - batch_size=self.batch_size, + batch_size=self.num_tiles_per_batch, num_workers=self.num_workers, shuffle=True, ) - def val_dataloader(self) -> DataLoader[Any]: + def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: """Return a DataLoader for validation. Returns: validation data loader """ return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, + self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False ) - def test_dataloader(self) -> DataLoader[Any]: + def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: """Return a DataLoader for testing. Returns: testing data loader """ return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, + self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False ) + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply augmentations to batch after transferring to GPU. 
+
+        Args:
+            batch: A batch of data that needs to be altered or augmented
+            dataloader_idx: The index of the dataloader to which the batch belongs
+
+        Returns:
+            A batch of data
+        """
+        # Kornia requires masks to have a channel dimension
+        batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w")
+
+        if self.trainer:
+            if self.trainer.training:
+                batch = self.train_transform(batch)
+            elif self.trainer.validating or self.trainer.testing:
+                batch = self.test_transform(batch)
+
+        # Torchmetrics does not support masks with a channel dimension
+        batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w")
+
+        return batch
+
     def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
         """Run :meth:`torchgeo.datasets.Vaihingen2D.plot`.
diff --git a/torchgeo/datasets/vaihingen.py b/torchgeo/datasets/vaihingen.py
index 4945247f574..8509607f2f0 100644
--- a/torchgeo/datasets/vaihingen.py
+++ b/torchgeo/datasets/vaihingen.py
@@ -188,7 +188,7 @@ def _load_image(self, index: int) -> Tensor:
             array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB"))
             tensor = torch.from_numpy(array)
             # Convert from HxWxC to CxHxW
-            tensor = tensor.permute((2, 0, 1))
+            tensor = tensor.permute((2, 0, 1)).float()
         return tensor
 
     def _load_target(self, index: int) -> Tensor:
diff --git a/train.py b/train.py
index cb001437686..5bad176717d 100755
--- a/train.py
+++ b/train.py
@@ -17,6 +17,7 @@
     BigEarthNetDataModule,
     ChesapeakeCVPRDataModule,
     COWCCountingDataModule,
+    DeepGlobeLandCoverDataModule,
     ETCI2021DataModule,
     EuroSATDataModule,
     InriaAerialImageLabelingDataModule,
@@ -30,6 +31,7 @@
     SpaceNet1DataModule,
     TropicalCycloneDataModule,
     UCMercedDataModule,
+    Vaihingen2DDataModule,
 )
 from torchgeo.trainers import (
     BYOLTask,
@@ -48,6 +50,7 @@
     "chesapeake_cvpr": (SemanticSegmentationTask, ChesapeakeCVPRDataModule),
     "cowc_counting": (RegressionTask, COWCCountingDataModule),
     "cyclone": (RegressionTask, TropicalCycloneDataModule),
+    "deepglobelandcover": (SemanticSegmentationTask, DeepGlobeLandCoverDataModule),
     "eurosat": (ClassificationTask, EuroSATDataModule),
     "etci2021": (SemanticSegmentationTask, ETCI2021DataModule),
     "inria": (SemanticSegmentationTask, InriaAerialImageLabelingDataModule),
@@ -60,6 +63,7 @@
     "so2sat": (ClassificationTask, So2SatDataModule),
     "spacenet1": (SemanticSegmentationTask, SpaceNet1DataModule),
     "ucmerced": (ClassificationTask, UCMercedDataModule),
+    "vaihingen2d": (SemanticSegmentationTask, Vaihingen2DDataModule),
 }
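
Usage note: the refactored datamodule can also be driven directly from Python rather than through `train.py`. Below is a minimal sketch of the intended flow. The datamodule arguments and values mirror `conf/vaihingen2d.yaml`; the `SemanticSegmentationTask` keyword set is assumed to match the `module` section of that config, and the one-epoch `Trainer` is illustrative only.

```python
import pytorch_lightning as pl

from torchgeo.datamodules import Vaihingen2DDataModule
from torchgeo.trainers import SemanticSegmentationTask

# During training, each batch holds 16 tiles x 16 random 64x64 patches per
# tile, i.e. an effective batch size of 256 patches. During validation and
# testing, each tile is chopped into a grid of 64x64 patches (batch_size=1).
datamodule = Vaihingen2DDataModule(
    root="data/vaihingen",
    num_tiles_per_batch=16,
    num_patches_per_tile=16,
    patch_size=64,  # should be a multiple of 32 for most segmentation models
    val_split_pct=0.5,
    num_workers=0,
)

# Keyword arguments mirror the "module" section of conf/vaihingen2d.yaml
task = SemanticSegmentationTask(
    model="unet",
    backbone="resnet18",
    weights=None,
    in_channels=3,
    num_classes=7,
    loss="ce",
    ignore_index=None,
    learning_rate=1e-3,
    learning_rate_schedule_patience=6,
)

trainer = pl.Trainer(max_epochs=1)
trainer.fit(model=task, datamodule=datamodule)
```

Pointing `train.py` at the new `conf/vaihingen2d.yaml` should be equivalent to this script.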