Skip to content

Commit

Permalink
Add VHR10 datamodule (#1082)
Browse files Browse the repository at this point in the history
* Add VHR10 datamodule

* Add newline

* patch_size accepts int and tuple of ints

* Update conf

* VHR10 Datamodule v2

* Remove auto_lr_find

* Remove preprocess

* Update config

* Remove setting of matplotlib backend

* Remove import

* Typing update

* Key fix

* Coverage fix

* Update conf

* Update conf

* Download=True

* Use weights

* Empty commit

* Switch to ndim

* Remove conf, tight_layout and spacing

* Set constrained layout via rcParams

* Revert and bump min matplotlib version

* Switch back to dataset_split

* Separate out AugPipe

* Increase figsize & revert matplotlib

* Common collate_fn

* Class var std

* Undo std change in BaseDataModule

* Undo req changes

* Remove unused line

* Add version strings

* mypy fix

---------

Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
ashnair1 and adamjstewart authored Jan 25, 2024
1 parent b0f5184 commit 0f8b0ac
Show file tree
Hide file tree
Showing 13 changed files with 243 additions and 55 deletions.
17 changes: 17 additions & 0 deletions tests/conf/vhr10.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
model:
  class_path: ObjectDetectionTask
  init_args:
    model: "faster-rcnn"
    backbone: "resnet50"
    num_classes: 11
    lr: 2.5e-5
    patience: 10
data:
  class_path: VHR10DataModule
  init_args:
    batch_size: 1
    num_workers: 0
    patch_size: 4
  dict_kwargs:
    root: "tests/data/vhr10"
    download: true
Binary file modified tests/data/vhr10/NWPU VHR-10 dataset.rar
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/data/vhr10/annotations.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"images": [{"file_name": "001.jpg", "height": 8, "width": 8, "id": 0}, {"file_name": "002.jpg", "height": 8, "width": 8, "id": 1}, {"file_name": "003.jpg", "height": 8, "width": 8, "id": 2}, {"file_name": "004.jpg", "height": 8, "width": 8, "id": 3}, {"file_name": "005.jpg", "height": 8, "width": 8, "id": 4}], "annotations": [{"id": 0, "image_id": 0, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "iscrowd": 0}, {"id": 1, "image_id": 1, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 2, "image_id": 2, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 3, "image_id": 3, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 4, "image_id": 4, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}]}
{"images": [{"file_name": "001.jpg", "height": 8, "width": 8, "id": 0}, {"file_name": "002.jpg", "height": 8, "width": 8, "id": 1}, {"file_name": "003.jpg", "height": 8, "width": 8, "id": 2}, {"file_name": "004.jpg", "height": 8, "width": 8, "id": 3}, {"file_name": "005.jpg", "height": 8, "width": 8, "id": 4}], "annotations": [{"id": 0, "image_id": 0, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 1, "image_id": 1, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 2, "image_id": 2, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 3, "image_id": 3, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 4, "image_id": 4, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}]}
10 changes: 2 additions & 8 deletions tests/data/vhr10/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import os
import shutil
import subprocess
from copy import deepcopy

import numpy as np
from PIL import Image
Expand Down Expand Up @@ -47,7 +46,7 @@ def generate_test_data(root: str, n_imgs: int = 3) -> str:
)

ann = 0
for i, img in enumerate(ANNOTATION_FILE["images"]):
for _, img in enumerate(ANNOTATION_FILE["images"]):
annot = {
"id": ann,
"image_id": img["id"],
Expand All @@ -57,12 +56,7 @@ def generate_test_data(root: str, n_imgs: int = 3) -> str:
"segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]],
"iscrowd": 0,
}
if i != 0:
ANNOTATION_FILE["annotations"].append(annot)
else:
noseg_annot = deepcopy(annot)
del noseg_annot["segmentation"]
ANNOTATION_FILE["annotations"].append(noseg_annot)
ANNOTATION_FILE["annotations"].append(annot)
ann += 1

with open(ann_file, "w") as j:
Expand Down
4 changes: 2 additions & 2 deletions tests/datasets/test_vhr10.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ def dataset(
monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url)
url = os.path.join("tests", "data", "vhr10", "NWPU VHR-10 dataset.rar")
monkeypatch.setitem(VHR10.image_meta, "url", url)
md5 = "5fddb0dfd56a80638831df9f90cbf37a"
md5 = "92769845cae6a4e8c74bfa1a0d1d4a80"
monkeypatch.setitem(VHR10.image_meta, "md5", md5)
url = os.path.join("tests", "data", "vhr10", "annotations.json")
monkeypatch.setitem(VHR10.target_meta, "url", url)
md5 = "833899cce369168e0d4ee420dac326dc"
md5 = "567c4cd8c12624864ff04865de504c58"
monkeypatch.setitem(VHR10.target_meta, "md5", md5)
root = str(tmp_path)
split = request.param
Expand Down
2 changes: 1 addition & 1 deletion tests/trainers/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def plot(*args: Any, **kwargs: Any) -> None:


class TestObjectDetectionTask:
@pytest.mark.parametrize("name", ["nasa_marine_debris"])
@pytest.mark.parametrize("name", ["nasa_marine_debris", "vhr10"])
@pytest.mark.parametrize("model_name", ["faster-rcnn", "fcos", "retinanet"])
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, model_name: str, fast_dev_run: bool
Expand Down
16 changes: 8 additions & 8 deletions tests/transforms/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def batch_gray() -> dict[str, Tensor]:
return {
"image": torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float),
"mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
"boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float),
"boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
"labels": torch.tensor([[0, 1]]),
}

Expand All @@ -42,7 +42,7 @@ def batch_rgb() -> dict[str, Tensor]:
dtype=torch.float,
),
"mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
"boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float),
"boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
"labels": torch.tensor([[0, 1]]),
}

Expand All @@ -63,7 +63,7 @@ def batch_multispectral() -> dict[str, Tensor]:
dtype=torch.float,
),
"mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
"boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float),
"boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
"labels": torch.tensor([[0, 1]]),
}

Expand All @@ -79,7 +79,7 @@ def test_augmentation_sequential_gray(batch_gray: dict[str, Tensor]) -> None:
expected = {
"image": torch.tensor([[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]], dtype=torch.float),
"mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
"boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float),
"boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float),
"labels": torch.tensor([[0, 1]]),
}
augs = transforms.AugmentationSequential(
Expand All @@ -102,7 +102,7 @@ def test_augmentation_sequential_rgb(batch_rgb: dict[str, Tensor]) -> None:
dtype=torch.float,
),
"mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
"boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float),
"boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float),
"labels": torch.tensor([[0, 1]]),
}
augs = transforms.AugmentationSequential(
Expand All @@ -129,7 +129,7 @@ def test_augmentation_sequential_multispectral(
dtype=torch.float,
),
"mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
"boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float),
"boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float),
"labels": torch.tensor([[0, 1]]),
}
augs = transforms.AugmentationSequential(
Expand All @@ -156,7 +156,7 @@ def test_augmentation_sequential_image_only(
dtype=torch.float,
),
"mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
"boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float),
"boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
"labels": torch.tensor([[0, 1]]),
}
augs = transforms.AugmentationSequential(
Expand Down Expand Up @@ -188,7 +188,7 @@ def test_sequential_transforms_augmentations(
dtype=torch.float,
),
"mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
"boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float),
"boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float),
"labels": torch.tensor([[0, 1]]),
}
train_transforms = transforms.AugmentationSequential(
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from .usavars import USAVarsDataModule
from .utils import MisconfigurationException
from .vaihingen import Vaihingen2DDataModule
from .vhr10 import VHR10DataModule
from .xview import XView2DataModule

__all__ = (
Expand Down Expand Up @@ -79,6 +80,7 @@
"UCMercedDataModule",
"USAVarsDataModule",
"Vaihingen2DDataModule",
"VHR10DataModule",
"XView2DataModule",
# Base classes
"BaseDataModule",
Expand Down
32 changes: 13 additions & 19 deletions torchgeo/datamodules/nasa_marine_debris.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,13 @@

from typing import Any

import kornia.augmentation as K
import torch
from torch import Tensor

from ..datasets import NASAMarineDebris
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule
from .utils import dataset_split


def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]:
    """Collate detection samples whose number of boxes differs per sample.

    Args:
        batch: list of sample dicts returned by the dataset; each sample
            provides an "image" and a "boxes" entry.

    Returns:
        dict with the images stacked into one tensor under "image", the
        per-sample boxes kept as a list under "boxes", and one tensor of ones
        per sample (one entry per box) under "labels".
    """
    images = torch.stack([item["image"] for item in batch])
    boxes = [item["boxes"] for item in batch]
    # Every box gets the label 1; the label tensor length follows the box count.
    labels = [torch.tensor([1] * len(item_boxes)) for item_boxes in boxes]
    return {"image": images, "boxes": boxes, "labels": labels}
from .utils import AugPipe, collate_fn_detection, dataset_split


class NASAMarineDebrisDataModule(NonGeoDataModule):
Expand All @@ -35,6 +20,8 @@ class NASAMarineDebrisDataModule(NonGeoDataModule):
.. versionadded:: 0.2
"""

std = torch.tensor(255)

def __init__(
self,
batch_size: int = 64,
Expand All @@ -58,7 +45,14 @@ def __init__(
self.val_split_pct = val_split_pct
self.test_split_pct = test_split_pct

self.collate_fn = collate_fn
self.aug = AugPipe(
AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "boxes"]
),
batch_size,
)

self.collate_fn = collate_fn_detection

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
87 changes: 85 additions & 2 deletions torchgeo/datamodules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

import math
from collections.abc import Iterable
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union

import numpy as np
from torch import Generator
import torch
from einops import rearrange
from torch import Generator, Tensor
from torch.nn import Module
from torch.utils.data import Subset, TensorDataset, random_split

from ..datasets import NonGeoDataset
Expand All @@ -19,6 +22,86 @@ class MisconfigurationException(Exception):
"""Exception used to inform users of misuse with Lightning."""


class AugPipe(Module):
    """Apply an augmentation callable to each sample of a batch in turn.

    .. versionadded:: 0.6
    """

    def __init__(
        self, augs: Callable[[dict[str, Any]], dict[str, Any]], batch_size: int
    ) -> None:
        """Create a new AugPipe.

        Args:
            augs: Callable mapping a single-sample dict to its augmented
                counterpart.
            batch_size: Batch size; stored on the instance as ``batch_size``.
        """
        super().__init__()
        self.batch_size = batch_size
        self.augs = augs

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Augment every sample of ``batch`` in place and return the batch.

        Args:
            batch: Batch dict with "image", "labels" and "boxes" entries (and
                optionally "masks"), each indexable per sample.

        Returns:
            The same dict, with each per-sample entry replaced by the output
            of ``augs`` and "image" rearranged from ``b () c h w`` to
            ``b c h w``.
        """
        num_samples = len(batch["image"])

        # "masks" is optional; it is carried through only when the batch has it.
        keys = ["image", "labels", "boxes"]
        if "masks" in batch:
            keys.append("masks")

        for idx in range(num_samples):
            sample = {key: batch[key][idx] for key in keys}
            augmented = self.augs(sample)
            for key in keys:
                batch[key][idx] = augmented[key]

        # Stack images, dropping the unit axis named by "()" in the pattern.
        batch["image"] = rearrange(batch["image"], "b () c h w -> b c h w")
        return batch


def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]:
    """Collate samples for object detection and instance segmentation.

    Every entry of the result is a list with one element per sample, so
    samples may carry different numbers of boxes.

    Args:
        batch: list of sample dicts returned by the dataset

    Returns:
        batch dict with "image", "boxes" (cast to float) and "labels" lists,
        plus a "masks" list when the first sample has masks

    .. versionadded:: 0.6
    """
    images = [item["image"] for item in batch]
    boxes = [item["boxes"].float() for item in batch]

    # Presence of optional keys is decided from the first sample only.
    first = batch[0]
    if "labels" in first:
        labels = [item["labels"] for item in batch]
    else:
        # No labels supplied: give every box the label 1.
        labels = [torch.tensor([1] * len(item["boxes"])) for item in batch]

    collated: dict[str, Any] = {"image": images, "boxes": boxes, "labels": labels}
    if "masks" in first:
        collated["masks"] = [item["masks"] for item in batch]
    return collated


def dataset_split(
dataset: Union[TensorDataset, NonGeoDataset],
val_pct: float,
Expand Down
Loading

0 comments on commit 0f8b0ac

Please sign in to comment.