Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Inria datamodule #498

Merged
merged 26 commits into from
May 22, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor
  • Loading branch information
ashnair1 committed May 19, 2022
commit cfe9765b91c3a92a899b9b87f0bcca30abc90a0a
19 changes: 15 additions & 4 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pytorch_lightning as pl
import rasterio as rio
import torch
from kornia.contrib import CombineTensorPatches
from omegaconf import OmegaConf

from torchgeo.datamodules import (
Expand Down Expand Up @@ -115,17 +116,27 @@ def main(config_dir: str, predict_on: str, output_dir: str, device: str) -> None
assert len(x.shape) in {4, 5}
if len(x.shape) == 5:
masks = []

def tensor_to_int(
tensor_tuple: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[int, ...]:
"""Convert tuple of tensors to tuple of ints."""
return tuple(int(i.item()) for i in tensor_tuple)

original_shape = tensor_to_int(batch["original_shape"])
patch_shape = tensor_to_int(batch["patch_shape"])
padding = tensor_to_int(batch["padding"])
patch_combine = CombineTensorPatches(
original_size=original_shape, window_size=patch_shape, unpadding=padding
)
for tile in x:
mask = task(tile)
mask = mask.argmax(dim=1)
masks.append(mask)

masks_arr = torch.stack(masks, dim=0)
masks_arr = masks_arr.unsqueeze(0)

if not hasattr(datamodule, "patch_combine"):
raise NotImplementedError
masks_combined = datamodule.patch_combine(masks_arr)[0]
masks_combined = patch_combine(masks_arr)[0]
filename = datamodule.predict_dataset.files[i]["image"]
write_mask(masks_combined, output_dir, filename)
else:
Expand Down
7 changes: 1 addition & 6 deletions tests/datamodules/test_inria.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import pytest
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch

from torchgeo.datamodules import InriaAerialImageLabelingDataModule

Expand All @@ -17,17 +16,13 @@ class TestInriaAerialImageLabelingDataModule:
@pytest.fixture(
params=zip([0.2, 0.2, 0.0], [0.2, 0.0, 0.0], ["test", PREDICT_DATA_DIR, "test"])
)
def datamodule(
self, request: SubRequest, monkeypatch: MonkeyPatch
) -> InriaAerialImageLabelingDataModule:
def datamodule(self, request: SubRequest) -> InriaAerialImageLabelingDataModule:
val_split_pct, test_split_pct, predict_on = request.param
patch_size = 2 # (2,2)
num_patches_per_tile = 2
root = TEST_DATA_DIR
batch_size = 1
num_workers = 0
monkeypatch.setattr(InriaAerialImageLabelingDataModule, "h", 8)
monkeypatch.setattr(InriaAerialImageLabelingDataModule, "w", 8)
dm = InriaAerialImageLabelingDataModule(
root,
batch_size,
Expand Down
49 changes: 18 additions & 31 deletions torchgeo/datamodules/inria.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
import torch
import torchvision.transforms as T
from einops import rearrange
from kornia.contrib import CombineTensorPatches, ExtractTensorPatches
from kornia.contrib import extract_tensor_patches
from torch.nn.modules.utils import _pair
from torch.utils.data import DataLoader, Dataset
from torch.utils.data._utils.collate import default_collate

from torchgeo.datamodules.utils import dataset_split
from torchgeo.datasets import InriaAerialImageLabeling
from torchgeo.datasets.utils import PredictDataset
from torchgeo.datasets.utils import PredictDataset, compute_padding

DEFAULT_AUGS = K.AugmentationSequential(
K.RandomHorizontalFlip(p=0.5),
Expand All @@ -27,23 +27,6 @@
)


# Maybe this can be moved to utils?
def compute_padding(
h: int, w: int, window_size: Union[int, Tuple[int, int]]
) -> Tuple[int, int]:
"""Compute required padding."""
window_size = cast(Tuple[int, int], _pair(window_size))
if (h % window_size[0]) == 0:
h_pad = 0
else:
h_pad = (window_size[0] - (h % window_size[0])) // 2
if (w % window_size[1]) == 0:
w_pad = 0
else:
w_pad = (window_size[1] - (w % window_size[1])) // 2
return (h_pad, w_pad)


def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Flatten wrapper."""
r_batch: Dict[str, Any] = default_collate(batch) # type: ignore[no-untyped-call]
Expand Down Expand Up @@ -106,19 +89,22 @@ def __init__(
K.RandomCrop(self.patch_size, p=1.0, keepdim=False),
data_keys=["input", "mask"],
)
padding = compute_padding(self.h, self.w, self.patch_size)
self.patch_extract = ExtractTensorPatches(
window_size=self.patch_size, stride=self.patch_size, padding=padding
)
self.patch_combine = CombineTensorPatches(
original_size=(self.h, self.w),
window_size=self.patch_size,
unpadding=padding,
)

def patch_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Extract patches from signle sample."""
sample["image"] = self.patch_extract(sample["image"].unsqueeze(0))
"""Extract patches from single sample."""
assert sample["image"].ndim == 3
_, h, w = sample["image"].shape

padding = compute_padding((h, w), self.patch_size)
sample["original_shape"] = (h, w)
sample["patch_shape"] = self.patch_size
sample["padding"] = padding
sample["image"] = extract_tensor_patches(
sample["image"].unsqueeze(0),
self.patch_size,
self.patch_size,
padding=padding,
)
sample["image"] = rearrange(sample["image"], "() t c h w -> t () c h w")
return sample

Expand Down Expand Up @@ -155,6 +141,7 @@ def setup(self, stage: Optional[str] = None) -> None:
"""
train_transforms = T.Compose([self.preprocess, self.n_random_crop])
test_transforms = T.Compose([self.preprocess, self.patch_sample])
predict_transforms = T.Compose([self.preprocess])

train_dataset = InriaAerialImageLabeling(
self.root_dir, split="train", transforms=train_transforms
Expand Down Expand Up @@ -183,7 +170,7 @@ def setup(self, stage: Optional[str] = None) -> None:

if os.path.isdir(self.predict_on):
self.predict_dataset = PredictDataset(
self.predict_on, transforms=test_transforms
self.predict_on, self.patch_size, transforms=predict_transforms
)
else:
assert self.predict_on == "test"
Expand Down
49 changes: 48 additions & 1 deletion torchgeo/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
import numpy as np
import rasterio
import torch
from einops import rearrange
from kornia.contrib import extract_tensor_patches
from torch import Tensor
from torch.nn.modules.utils import _pair
from torch.utils.data import Dataset
from torchvision.datasets.utils import check_integrity, download_url
from torchvision.utils import draw_segmentation_masks
Expand All @@ -57,29 +60,68 @@
)


def compute_padding(
original_size: Union[int, Tuple[int, int]], window_size: Union[int, Tuple[int, int]]
) -> Tuple[int, int]:
"""Compute required padding."""
original_size = cast(Tuple[int, int], _pair(original_size))
window_size = cast(Tuple[int, int], _pair(window_size))

if (original_size[0] % window_size[0]) == 0:
h_pad = 0
else:
h_pad = (window_size[0] - (original_size[0] % window_size[0])) // 2
if (original_size[1] % window_size[1]) == 0:
w_pad = 0
else:
w_pad = (window_size[1] - (original_size[1] % window_size[1])) // 2
return (h_pad, w_pad)


class PredictDataset(Dataset[Any]):
"""Prediction dataset for VisionDatasets."""

def __init__(
self,
root: str,
patch_size: Tuple[int, int],
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
bands: Tuple[int, ...] = (1, 2, 3),
) -> None:
"""Initialize a new PredictDataset instance.

Args:
root: root directory where dataset can be found
patch_size: Size of patch used as input for the model.
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version.
bands: bands to be used.

"""
self.root = root
self.patch_size = patch_size
self.transforms = transforms
self.bands = bands
self.files = self._load_files(root)

def patch_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Extract patches from single sample."""
assert sample["image"].ndim == 3
_, h, w = sample["image"].shape

padding = compute_padding((h, w), self.patch_size)
sample["original_shape"] = (h, w)
sample["patch_shape"] = self.patch_size
sample["padding"] = padding
sample["image"] = extract_tensor_patches(
sample["image"].unsqueeze(0),
self.patch_size,
self.patch_size,
padding=padding,
)
sample["image"] = rearrange(sample["image"], "() t c h w -> t () c h w")
return sample

def _load_files(self, root: str) -> List[Dict[str, str]]:
"""Return the paths of the files in the dataset.

Expand Down Expand Up @@ -127,7 +169,12 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]:
img = self._load_image(file["image"])
sample = {"image": img}
if self.transforms is not None:
sample = self.transforms(sample)
self.transforms.transforms.append( # type:ignore[attr-defined]
self.patch_sample
)
else:
self.transforms = self.patch_sample
sample = self.transforms(sample)
return sample


Expand Down