Add random crop logic to DeepGlobeLandCover Datamodule #876
@@ -3,87 +3,92 @@
"""DeepGlobe Land Cover Classification Challenge datamodule.""" | ||
|
||
from typing import Any, Dict, Optional | ||
from typing import Any, Dict, Optional, Tuple, Union | ||
|
||
import matplotlib.pyplot as plt | ||
import pytorch_lightning as pl | ||
from torch.utils.data import DataLoader, Dataset | ||
from torchvision.transforms import Compose | ||
from kornia.augmentation import Normalize | ||
Comment on lines -11 to +10:

I'm planning on removing all torchvision transforms. Torchvision relies on PIL for many of its transforms, which doesn't support MSI. Kornia has all of the same transforms, but they are in pure PyTorch, so they can run on the GPU and support MSI. I don't see a good reason not to only use Kornia transforms.
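To illustrate the MSI point, here is a minimal sketch (not part of the diff) showing Kornia augmentations applied to a multispectral tensor on the GPU; the band count and sizes are made up.

```python
# Illustrative only: Kornia augmentations are plain nn.Modules operating on
# tensors, so they accept any number of channels and can run on the GPU.
import torch
from kornia.augmentation import Normalize, RandomCrop

batch = torch.rand(2, 8, 256, 256)  # a fake 8-band multispectral batch (B, C, H, W)

augment = torch.nn.Sequential(
    Normalize(mean=0.0, std=255.0),  # same normalization used in this datamodule
    RandomCrop((64, 64)),
)

if torch.cuda.is_available():
    batch, augment = batch.cuda(), augment.cuda()

out = augment(batch)
print(out.shape)  # torch.Size([2, 8, 64, 64]) -- all 8 bands preserved
```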
from torch.utils.data import DataLoader

from ..datasets import DeepGlobeLandCover
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop
from .utils import dataset_split


class DeepGlobeLandCoverDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the DeepGlobe Land Cover dataset.

    Uses the train/test splits from the dataset.

    """

    def __init__(
        self,
        batch_size: int = 64,
        num_workers: int = 0,
        num_tiles_per_batch: int = 16,
        num_patches_per_tile: int = 16,
        patch_size: Union[Tuple[int, int], int] = 64,
        val_split_pct: float = 0.2,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for DeepGlobe Land Cover based DataLoaders.
        """Initialize a new LightningDataModule instance.

        The DeepGlobe Land Cover dataset contains images that are too large to pass
        directly through a model. Instead, we randomly sample patches from image tiles
        during training and chop up image tiles into patch grids during evaluation.
        During training, the effective batch size is equal to
        ``num_tiles_per_batch`` x ``num_patches_per_tile``.

        Args:
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            val_split_pct: What percentage of the dataset to use as a validation set
            num_tiles_per_batch: The number of image tiles to sample from during
                training
            num_patches_per_tile: The number of patches to randomly sample from each
                image tile during training
            patch_size: The size of each patch, either ``size`` or ``(height, width)``.
                Should be a multiple of 32 for most segmentation architectures
            val_split_pct: The percentage of the dataset to use as a validation set
            num_workers: The number of workers to use for parallel data loading
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.DeepGlobeLandCover`
        .. versionchanged:: 0.4
           *batch_size* was replaced by *num_tiles_per_batch*, *num_patches_per_tile*,
           and *patch_size*.
Comment on lines +56 to +57:

I tend to only document API changes, not internal changes. So the fact that we're now using random cropping isn't documented, only that the parameters changed.
""" | ||
super().__init__() | ||
self.batch_size = batch_size | ||
self.num_workers = num_workers | ||
|
||
self.num_tiles_per_batch = num_tiles_per_batch | ||
self.num_patches_per_tile = num_patches_per_tile | ||
self.patch_size = _to_tuple(patch_size) | ||
self.val_split_pct = val_split_pct | ||
self.num_workers = num_workers | ||
self.kwargs = kwargs | ||
|
||
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: | ||
We cannot use instance methods as transforms, see #886 for what happens when we do.
"""Transform a single sample from the Dataset. | ||
|
||
Args: | ||
sample: input image dictionary | ||
|
||
Returns: | ||
preprocessed sample | ||
""" | ||
sample["image"] = sample["image"].float() | ||
sample["image"] /= 255.0 | ||
return sample | ||
self.train_transform = AugmentationSequential( | ||
Normalize(mean=0.0, std=255.0), | ||
_RandomNCrop(self.patch_size, self.num_patches_per_tile), | ||
data_keys=["image", "mask"], | ||
) | ||
self.test_transform = AugmentationSequential( | ||
Normalize(mean=0.0, std=255.0), | ||
_ExtractTensorPatches(self.patch_size), | ||
data_keys=["image", "mask"], | ||
) | ||
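For intuition on the evaluation path, here is a rough sketch of the gridding behaviour using `kornia.contrib.extract_tensor_patches` directly; the assumption is that `_ExtractTensorPatches` is a thin wrapper around this kind of patch extraction, and the sizes are placeholders.

```python
# Rough sketch of what "chop up image tiles into patch grids" means at eval time.
import torch
from kornia.contrib import extract_tensor_patches

tile = torch.rand(1, 3, 256, 256)  # one RGB tile (batch_size=1 in val/test loaders)

# Non-overlapping 64x64 windows: stride == window_size.
patches = extract_tensor_patches(tile, window_size=64, stride=64)
print(patches.shape)  # torch.Size([1, 16, 3, 64, 64]) -- a 4x4 grid of patches

# Folding the patch dimension into the batch dimension gives model-sized inputs.
patches = patches.flatten(0, 1)  # (16, 3, 64, 64)
```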

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.
        """Initialize the main Dataset objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        transforms = Compose([self.preprocess])

        dataset = DeepGlobeLandCover(
            split="train", transforms=transforms, **self.kwargs
        )

        self.train_dataset: Dataset[Any]
        self.val_dataset: Dataset[Any]

        if self.val_split_pct > 0.0:
            self.train_dataset, self.val_dataset, _ = dataset_split(
                dataset, val_pct=self.val_split_pct, test_pct=0.0
            )
        else:
            self.train_dataset = dataset
            self.val_dataset = dataset
Comment on lines -73 to -82:

I have no idea why our previous logic was so complicated, but I don't think it needs to be.
        self.test_dataset = DeepGlobeLandCover(
            split="test", transforms=transforms, **self.kwargs
        train_dataset = DeepGlobeLandCover(split="train", **self.kwargs)
        self.train_dataset, self.val_dataset = dataset_split(
            train_dataset, self.val_split_pct
        )
        self.test_dataset = DeepGlobeLandCover(split="test", **self.kwargs)

    def train_dataloader(self) -> DataLoader[Dict[str, Any]]:
        """Return a DataLoader for training.

@@ -93,7 +98,7 @@ def train_dataloader(self) -> DataLoader[Dict[str, Any]]:
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            batch_size=self.num_tiles_per_batch,
            num_workers=self.num_workers,
            shuffle=True,
        )

@@ -105,10 +110,7 @@ def val_dataloader(self) -> DataLoader[Dict[str, Any]]:
            validation data loader
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
        )

    def test_dataloader(self) -> DataLoader[Dict[str, Any]]:

@@ -118,12 +120,32 @@ def test_dataloader(self) -> DataLoader[Dict[str, Any]]:
            testing data loader
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
        )

    def on_after_batch_transfer(
        self, batch: Dict[str, Any], dataloader_idx: int
    ) -> Dict[str, Any]:
        """Apply augmentations to batch after transferring to GPU.

        Args:
            batch: A batch of data that needs to be altered or augmented
            dataloader_idx: The index of the dataloader to which the batch belongs

        Returns:
            A batch of data
        """
        if self.trainer:
            if self.trainer.training:
Comment on lines +141 to +142:

So much cleaner than our previous logic!
                batch = self.train_transform(batch)
            elif self.trainer.validating or self.trainer.testing:
                batch = self.test_transform(batch)

            # Kornia adds a channel dimension to the mask
            batch["mask"] = batch["mask"].squeeze(1)
Kornia does a lot of weird stuff with transforms that I don't like. Masks are required to be floats (why? slower, more storage). If the mask you input doesn't have a channel dimension, it will add one. Some of the transforms actually break if the mask doesn't have a channel dimension when you input it, so we may need to add an unsqueeze above.
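A small sketch of the shape round trip described in this comment (illustrative, not the datamodule's actual code): Kornia wants a float mask with a channel dimension, while the loss side typically wants integer class indices of shape (B, H, W).

```python
import torch

mask = torch.randint(0, 7, (4, 256, 256))    # (B, H, W) integer class labels

mask_in = mask.float().unsqueeze(1)          # (B, 1, H, W) float, as Kornia expects
# ... image and mask would be transformed together here ...
mask_out = mask_in                           # stand-in for the transformed mask

mask_restored = mask_out.squeeze(1).long()   # back to (B, H, W) class indices
assert mask_restored.shape == mask.shape
```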
        return batch

    def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
        """Run :meth:`torchgeo.datasets.DeepGlobeLandCover.plot`.
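For context, a minimal usage sketch of the new API (the root path and hyperparameter values are placeholders, and the model is left out):

```python
import pytorch_lightning as pl
from torchgeo.datamodules import DeepGlobeLandCoverDataModule

datamodule = DeepGlobeLandCoverDataModule(
    root="data/deepglobe",   # placeholder; forwarded to DeepGlobeLandCover via **kwargs
    num_tiles_per_batch=16,  # tiles drawn per training batch
    num_patches_per_tile=16, # random crops taken from each tile
    patch_size=64,           # becomes (64, 64) via _to_tuple
    val_split_pct=0.2,
    num_workers=4,
)

# Effective training batch size: 16 tiles x 16 patches = 256 patches per step.
# The cropping itself happens in on_after_batch_transfer, i.e. after the batch
# of whole tiles has been moved to the GPU.
trainer = pl.Trainer(max_epochs=1)
# trainer.fit(model, datamodule=datamodule)  # `model` would be a LightningModule
```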
@@ -4,13 +4,23 @@
"""Common sampler utilities.""" | ||
|
||
import math | ||
from typing import Optional, Tuple, Union | ||
from typing import Optional, Tuple, Union, overload | ||
|
||
import torch | ||
|
||
from ..datasets import BoundingBox | ||
|
||
|
||
@overload | ||
def _to_tuple(value: Union[Tuple[int, int], int]) -> Tuple[int, int]: | ||
... | ||
|
||
|
||
@overload | ||
def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: | ||
... | ||
Comment on lines +14 to +21:

In Python typing, all ints are floats, but not all floats are ints. This meant that if I pass an int as input, mypy would consider its output type to be float. These overloads ensure that int maps to int and float maps to float as expected.
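A self-contained sketch of what the overloads buy at type-check time; the function body here is a plausible stand-in rather than the exact torchgeo implementation.

```python
from typing import Tuple, Union, overload


@overload
def _to_tuple(value: Union[Tuple[int, int], int]) -> Tuple[int, int]: ...
@overload
def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: ...
def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]:
    """Convert value to a tuple if it is not already a tuple."""
    if isinstance(value, (float, int)):
        return (value, value)
    return value


size = _to_tuple(64)   # mypy infers Tuple[int, int]; without the overloads this
                       # would be Tuple[float, float], since int is accepted as float
res = _to_tuple(0.5)   # mypy infers Tuple[float, float] either way
```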


def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]:
    """Convert value to a tuple if it is not already a tuple.
In Kornia 0.6.5+, all augmentation instance methods have a new `flags` parameter. So the transforms I added won't work with Kornia 0.6.4 and older. Once we upstream these transforms to Kornia, we'll need to depend on an even newer version anyway.
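For reference, a hedged sketch of the signature change being referred to; the base class and exact parameter list are my best understanding of Kornia 0.6.5+ and may differ slightly between releases.

```python
from typing import Any, Dict, Optional

import torch
from torch import Tensor
from kornia.augmentation import IntensityAugmentationBase2D


class _Identity(IntensityAugmentationBase2D):
    """A do-nothing augmentation, only to show the new ``flags`` argument."""

    def apply_transform(
        self,
        input: Tensor,
        params: Dict[str, Tensor],
        flags: Dict[str, Any],  # new in Kornia 0.6.5+; absent in 0.6.4 and older
        transform: Optional[Tensor] = None,
    ) -> Tensor:
        return input


aug = _Identity(p=1.0)
out = aug(torch.rand(2, 3, 32, 32))  # shape and values unchanged
```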