
Fix Vaihingen datamodule #853

Merged · 17 commits · Dec 30, 2022
21 changes: 21 additions & 0 deletions conf/deepglobelandcover.yaml
@@ -0,0 +1,21 @@
+experiment:
+  task: "deepglobelandcover"
+  module:
+    loss: "ce"
+    model: "unet"
+    backbone: "resnet18"
+    weights: null
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 6
+    verbose: false
+    in_channels: 3
+    num_classes: 7
+    num_filters: 1
+    ignore_index: null
+  datamodule:
+    root: "data/deepglobelandcover"
+    num_tiles_per_batch: 16
+    num_patches_per_tile: 16
+    patch_size: 64
+    val_split_pct: 0.5
+    num_workers: 0
21 changes: 21 additions & 0 deletions conf/vaihingen2d.yaml
@@ -0,0 +1,21 @@
+experiment:
+  task: "vaihingen2d"
+  module:
+    loss: "ce"
+    model: "unet"
+    backbone: "resnet18"
+    weights: null
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 6
+    verbose: false
+    in_channels: 3
+    num_classes: 7
+    num_filters: 1
+    ignore_index: null
+  datamodule:
+    root: "data/vaihingen"
+    num_tiles_per_batch: 16
+    num_patches_per_tile: 16
+    patch_size: 64
+    val_split_pct: 0.5
+    num_workers: 0
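
These two experiment configs follow the same layout: the `module` block holds the task hyperparameters and the `datamodule` block holds the datamodule constructor arguments. A minimal sketch of how such a config is consumed, assuming torchgeo 0.4's public classes and OmegaConf (the repository's train.py performs the equivalent wiring, including resolving the `task` name to these classes):

```python
from omegaconf import OmegaConf

from torchgeo.datamodules import Vaihingen2DDataModule
from torchgeo.trainers import SemanticSegmentationTask

# Load the experiment config shown above.
conf = OmegaConf.load("conf/vaihingen2d.yaml")
experiment = OmegaConf.to_object(conf.experiment)

# The "module" block becomes the task's hyperparameters; the "datamodule"
# block constructs the datamodule (root, num_tiles_per_batch, ...).
task = SemanticSegmentationTask(**experiment["module"])
datamodule = Vaihingen2DDataModule(**experiment["datamodule"])
```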
21 changes: 21 additions & 0 deletions tests/conf/vaihingen2d.yaml
@@ -0,0 +1,21 @@
+experiment:
+  task: "vaihingen2d"
+  module:
+    loss: "ce"
+    model: "unet"
+    backbone: "resnet18"
+    weights: null
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 6
+    verbose: false
+    in_channels: 3
+    num_classes: 7
+    num_filters: 1
+    ignore_index: null
+  datamodule:
+    root: "tests/data/vaihingen"
+    num_tiles_per_batch: 1
+    num_patches_per_tile: 1
+    patch_size: 2
+    val_split_pct: 0.5
+    num_workers: 0
44 changes: 0 additions & 44 deletions tests/datamodules/test_vaihingen.py

This file was deleted.

2 changes: 2 additions & 0 deletions tests/trainers/test_segmentation.py
@@ -21,6 +21,7 @@
     NAIPChesapeakeDataModule,
     SEN12MSDataModule,
     SpaceNet1DataModule,
+    Vaihingen2DDataModule,
 )
 from torchgeo.datasets import LandCoverAI
 from torchgeo.trainers import SemanticSegmentationTask
@@ -50,6 +51,7 @@ class TestSemanticSegmentationTask:
("sen12ms_s2_all", SEN12MSDataModule),
("sen12ms_s2_reduced", SEN12MSDataModule),
("spacenet1", SpaceNet1DataModule),
("vaihingen2d", Vaihingen2DDataModule),
],
)
def test_trainer(
16 changes: 9 additions & 7 deletions torchgeo/datamodules/deepglobelandcover.py
@@ -7,7 +7,9 @@

 import matplotlib.pyplot as plt
 import pytorch_lightning as pl
+from einops import rearrange
 from kornia.augmentation import Normalize
+from torch import Tensor
 from torch.utils.data import DataLoader
 
 from ..datasets import DeepGlobeLandCover
@@ -90,7 +92,7 @@ def setup(self, stage: Optional[str] = None) -> None:
         )
         self.test_dataset = DeepGlobeLandCover(split="test", **self.kwargs)
 
-    def train_dataloader(self) -> DataLoader[Dict[str, Any]]:
+    def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Return a DataLoader for training.
 
         Returns:
@@ -103,7 +105,7 @@ def train_dataloader(self) -> DataLoader[Dict[str, Any]]:
             shuffle=True,
         )
 
-    def val_dataloader(self) -> DataLoader[Dict[str, Any]]:
+    def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Return a DataLoader for validation.
 
         Returns:
@@ -113,7 +115,7 @@ def val_dataloader(self) -> DataLoader[Dict[str, Any]]:
             self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
         )
 
-    def test_dataloader(self) -> DataLoader[Dict[str, Any]]:
+    def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Return a DataLoader for testing.
 
         Returns:
@@ -124,8 +126,8 @@ def test_dataloader(self) -> DataLoader[Dict[str, Any]]:
         )
 
     def on_after_batch_transfer(
-        self, batch: Dict[str, Any], dataloader_idx: int
-    ) -> Dict[str, Any]:
+        self, batch: Dict[str, Tensor], dataloader_idx: int
+    ) -> Dict[str, Tensor]:
         """Apply augmentations to batch after transferring to GPU.
 
         Args:
@@ -136,7 +138,7 @@ def on_after_batch_transfer(
             A batch of data
         """
         # Kornia requires masks to have a channel dimension
-        batch["mask"] = batch["mask"].unsqueeze(1)
+        batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w")
 
         if self.trainer:
             if self.trainer.training:
@@ -145,7 +147,7 @@ def on_after_batch_transfer(
                 batch = self.test_transform(batch)
 
         # Torchmetrics does not support masks with a channel dimension
-        batch["mask"] = batch["mask"].squeeze(1)
+        batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w")
 
         return batch
 
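The switch from `unsqueeze`/`squeeze` to `einops.rearrange` above is behavior-preserving; a standalone sanity check (not part of the PR) that the two spellings agree:

```python
import torch
from einops import rearrange

mask = torch.zeros(8, 64, 64)  # a batch of masks: (batch, height, width)

# "b h w -> b () h w" inserts a unit channel axis, exactly like unsqueeze(1);
# Kornia's augmentations expect masks in (B, C, H, W) layout.
with_channel = rearrange(mask, "b h w -> b () h w")
assert torch.equal(with_channel, mask.unsqueeze(1))

# "b () h w -> b h w" drops the unit axis again, exactly like squeeze(1);
# torchmetrics wants masks back in (B, H, W).
assert torch.equal(rearrange(with_channel, "b () h w -> b h w"), mask)
```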
136 changes: 83 additions & 53 deletions torchgeo/datamodules/vaihingen.py
@@ -3,14 +3,19 @@

"""Vaihingen datamodule."""

from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple, Union

import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose
from einops import rearrange
from kornia.augmentation import Normalize
from torch import Tensor
from torch.utils.data import DataLoader

from ..datasets import Vaihingen2D
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop
from .utils import dataset_split


Expand All @@ -24,105 +29,130 @@ class Vaihingen2DDataModule(pl.LightningDataModule):

     def __init__(
         self,
-        batch_size: int = 64,
-        num_workers: int = 0,
+        num_tiles_per_batch: int = 16,
+        num_patches_per_tile: int = 16,
+        patch_size: Union[Tuple[int, int], int] = 64,
         val_split_pct: float = 0.2,
+        num_workers: int = 0,
         **kwargs: Any,
     ) -> None:
"""Initialize a LightningDataModule for Vaihingen2D based DataLoaders.
"""Initialize a new LightningDataModule instance.

The Vaihingen2D dataset contains images that are too large to pass
directly through a model. Instead, we randomly sample patches from image tiles
during training and chop up image tiles into patch grids during evaluation.
During training, the effective batch size is equal to
``num_tiles_per_batch`` x ``num_patches_per_tile``.

Args:
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
val_split_pct: What percentage of the dataset to use as a validation set
num_tiles_per_batch: The number of image tiles to sample from during
training
num_patches_per_tile: The number of patches to randomly sample from each
image tile during training
patch_size: The size of each patch, either ``size`` or ``(height, width)``.
Should be a multiple of 32 for most segmentation architectures
val_split_pct: The percentage of the dataset to use as a validation set
num_workers: The number of workers to use for parallel data loading
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.Vaihingen2D`

.. versionchanged:: 0.4
*batch_size* was replaced by *num_tile_per_batch*, *num_patches_per_tile*,
and *patch_size*.
"""
super().__init__()
self.batch_size = batch_size
self.num_workers = num_workers

self.num_tiles_per_batch = num_tiles_per_batch
self.num_patches_per_tile = num_patches_per_tile
self.patch_size = _to_tuple(patch_size)
self.val_split_pct = val_split_pct
self.num_workers = num_workers
self.kwargs = kwargs

def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.

Args:
sample: input image dictionary

Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
return sample
self.train_transform = AugmentationSequential(
Normalize(mean=0.0, std=255.0),
_RandomNCrop(self.patch_size, self.num_patches_per_tile),
data_keys=["image", "mask"],
)
self.test_transform = AugmentationSequential(
Normalize(mean=0.0, std=255.0),
_ExtractTensorPatches(self.patch_size),
data_keys=["image", "mask"],
)

     def setup(self, stage: Optional[str] = None) -> None:
-        """Initialize the main ``Dataset`` objects.
+        """Initialize the main Dataset objects.
 
         This method is called once per GPU per run.
 
         Args:
             stage: stage to set up
         """
-        transforms = Compose([self.preprocess])
-
-        dataset = Vaihingen2D(split="train", transforms=transforms, **self.kwargs)
-
-        self.train_dataset: Dataset[Any]
-        self.val_dataset: Dataset[Any]
-
-        if self.val_split_pct > 0.0:
-            self.train_dataset, self.val_dataset, _ = dataset_split(
-                dataset, val_pct=self.val_split_pct, test_pct=0.0
-            )
-        else:
-            self.train_dataset = dataset
-            self.val_dataset = dataset
-
-        self.test_dataset = Vaihingen2D(
-            split="test", transforms=transforms, **self.kwargs
+        train_dataset = Vaihingen2D(split="train", **self.kwargs)
+        self.train_dataset, self.val_dataset = dataset_split(
+            train_dataset, self.val_split_pct
         )
+        self.test_dataset = Vaihingen2D(split="test", **self.kwargs)

-    def train_dataloader(self) -> DataLoader[Any]:
+    def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Return a DataLoader for training.
 
         Returns:
             training data loader
         """
         return DataLoader(
             self.train_dataset,
-            batch_size=self.batch_size,
+            batch_size=self.num_tiles_per_batch,
             num_workers=self.num_workers,
             shuffle=True,
         )
 
-    def val_dataloader(self) -> DataLoader[Any]:
+    def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Return a DataLoader for validation.
 
         Returns:
             validation data loader
         """
         return DataLoader(
-            self.val_dataset,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            shuffle=False,
+            self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
         )
 
-    def test_dataloader(self) -> DataLoader[Any]:
+    def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Return a DataLoader for testing.
 
         Returns:
             testing data loader
         """
         return DataLoader(
-            self.test_dataset,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            shuffle=False,
+            self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
         )

+    def on_after_batch_transfer(
+        self, batch: Dict[str, Tensor], dataloader_idx: int
+    ) -> Dict[str, Tensor]:
+        """Apply augmentations to batch after transferring to GPU.
+
+        Args:
+            batch: A batch of data that needs to be altered or augmented
+            dataloader_idx: The index of the dataloader to which the batch belongs
+
+        Returns:
+            A batch of data
+        """
+        # Kornia requires masks to have a channel dimension
+        batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w")
+
+        if self.trainer:
+            if self.trainer.training:
+                batch = self.train_transform(batch)
+            elif self.trainer.validating or self.trainer.testing:
+                batch = self.test_transform(batch)
+
+        # Torchmetrics does not support masks with a channel dimension
+        batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w")
+
+        return batch
+
     def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
         """Run :meth:`torchgeo.datasets.Vaihingen2D.plot`.
 
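Putting the reworked datamodule together end to end: a hedged training sketch, assuming torchgeo 0.4 APIs and a manually prepared copy of the ISPRS Vaihingen data under `data/vaihingen` (hyperparameter values are illustrative, not prescribed by the PR):

```python
import pytorch_lightning as pl

from torchgeo.datamodules import Vaihingen2DDataModule
from torchgeo.trainers import SemanticSegmentationTask

# Each training batch holds num_tiles_per_batch full tiles; after transfer,
# on_after_batch_transfer crops num_patches_per_tile random patches from each,
# so one optimizer step sees 16 * 16 = 256 patches of 64x64. Validation and
# testing instead grid whole tiles into patches (batch_size=1).
datamodule = Vaihingen2DDataModule(
    root="data/vaihingen",
    num_tiles_per_batch=16,
    num_patches_per_tile=16,
    patch_size=64,  # multiple of 32, per the docstring's recommendation
    val_split_pct=0.1,
    num_workers=4,
)
task = SemanticSegmentationTask(
    model="unet",
    backbone="resnet18",
    weights=None,
    in_channels=3,
    num_classes=7,
    loss="ce",
    ignore_index=None,
    learning_rate=1e-3,
    learning_rate_schedule_patience=6,
)

trainer = pl.Trainer(max_epochs=10)
trainer.fit(task, datamodule=datamodule)
```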