From 89370c575a3c3c97064f3c63d83214f0e9f9f2b9 Mon Sep 17 00:00:00 2001
From: Nils Lehmann
Date: Sun, 16 Oct 2022 11:50:32 +0200
Subject: [PATCH 01/17] fix datamodule

---
 tests/datamodules/test_vaihingen.py |   2 +-
 torchgeo/datamodules/vaihingen.py   | 118 ++++++++++++++++++++--------
 2 files changed, 86 insertions(+), 34 deletions(-)

diff --git a/tests/datamodules/test_vaihingen.py b/tests/datamodules/test_vaihingen.py
index 13fd2d52e4c..896faf0000d 100644
--- a/tests/datamodules/test_vaihingen.py
+++ b/tests/datamodules/test_vaihingen.py
@@ -20,7 +20,7 @@ def datamodule(self, request: SubRequest) -> Vaihingen2DDataModule:
         val_split_size = request.param
         dm = Vaihingen2DDataModule(
             root=root,
-            batch_size=batch_size,
+            train_batch_size=batch_size,
             num_workers=num_workers,
             val_split_pct=val_split_size,
         )
diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py
index 6017b1cce77..8eca4cd9b93 100644
--- a/torchgeo/datamodules/vaihingen.py
+++ b/torchgeo/datamodules/vaihingen.py
@@ -3,11 +3,15 @@
 
 """Vaihingen datamodule."""
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 import matplotlib.pyplot as plt
+import kornia.augmentation as K
 import pytorch_lightning as pl
+import torch
+from einops import repeat
 from torch.utils.data import DataLoader, Dataset
+from torch.utils.data._utils.collate import default_collate
 from torchvision.transforms import Compose
 
 from ..datasets import Vaihingen2D
@@ -24,26 +28,44 @@ class Vaihingen2DDataModule(pl.LightningDataModule):
 
     def __init__(
         self,
-        batch_size: int = 64,
+        train_batch_size: int = 32,
         num_workers: int = 0,
+        patch_size: Tuple[int, int] = (64, 64),
+        num_patches_per_tile: int = 32,
         val_split_pct: float = 0.2,
         **kwargs: Any,
     ) -> None:
         """Initialize a LightningDataModule for Vaihingen2D based DataLoaders.
 
         Args:
-            batch_size: The batch size to use in all created DataLoaders
+            train_batch_size: The batch size used in the train DataLoader
+                (val_batch_size == test_batch_size == 1). The effective batch size
+                will be 'train_batch_size' * 'num_patches_per_tile'
             num_workers: The number of workers to use in all created DataLoaders
             val_split_pct: What percentage of the dataset to use as a validation set
+            patch_size: Size of random patch from image and mask (height, width), should
+                be a multiple of 32 for most segmentation architectures
+            num_patches_per_tile: number of random patches per sample
             **kwargs: Additional keyword arguments passed to
                 :class:`~torchgeo.datasets.Vaihingen2D`
+
+        :: versionchanged:: 0.4
+            'batch_size' is renamed to 'train_batch_size', 'patch_size' and
+            'num_patches_per_tile' introduced in order to randomly crop the
+            variable size images during training
        """
         super().__init__()
-        self.batch_size = batch_size
+        self.train_batch_size = train_batch_size
         self.num_workers = num_workers
+        self.patch_size = patch_size
+        self.num_patches_per_tile = num_patches_per_tile
         self.val_split_pct = val_split_pct
         self.kwargs = kwargs
 
+        self.rcrop = K.AugmentationSequential(
+            K.RandomCrop(patch_size), data_keys=["input", "mask"], same_on_batch=True
+        )
+
     def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
         """Transform a single sample from the Dataset.
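
A quick aside before the next hunk: the rcrop object added above relies on Kornia applying one sampled crop to every tensor it is given, which is what keeps image and mask spatially aligned. A minimal, self-contained sketch of that behavior, with illustrative shapes and a hypothetical 7-class mask (not taken from the patch itself):

import kornia.augmentation as K
import torch

# One RandomCrop is shared between image and mask via data_keys, so both
# tensors are cropped at the same window.
rcrop = K.AugmentationSequential(
    K.RandomCrop((64, 64)), data_keys=["input", "mask"]
)

image = torch.rand(1, 3, 256, 256)                    # (B, C, H, W) image tile
mask = torch.randint(0, 7, (1, 1, 256, 256)).float()  # Kornia expects float masks

image_patch, mask_patch = rcrop(image, mask)
# image_patch: (1, 3, 64, 64); mask_patch: (1, 1, 64, 64), same crop window
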
@@ -64,63 +86,93 @@ def setup(self, stage: Optional[str] = None) -> None:
 
         Args:
             stage: stage to set up
+
+        :: versionchanged:: 0.4
+            Add functionality to randomly crop patches from a tile during
+            training and pad validation and test samples to next multiple of 32
         """
-        transforms = Compose([self.preprocess])
-        dataset = Vaihingen2D(split="train", transforms=transforms, **self.kwargs)
+
+        def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]:
+            images, masks = [], []
+            for i in range(self.num_patches_per_tile):
+                mask = repeat(sample["mask"], "h w -> t h w", t=2).float()
+                image, mask = self.rcrop(sample["image"], mask)
+                mask = mask.squeeze()[0]
+                images.append(image.squeeze())
+                masks.append(mask.long())
+            sample["image"] = torch.stack(images).squeeze(0)
+            sample["mask"] = torch.stack(masks).squeeze(0)
+            return sample
+
+        def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]:
+            """Pad to next multiple of 32."""
+            h, w = sample["image"].shape[1], sample["image"].shape[2]
+            new_h = int(32 * ((h // 32) + 1))
+            new_w = int(32 * ((w // 32) + 1))
+
+            padto = K.PadTo((new_h, new_w))
+
+            sample["image"] = padto(sample["image"])[0]
+            sample["mask"] = padto(sample["mask"].float()).long()[0, 0]
+            return sample
+
+        train_transforms = Compose([self.preprocess, n_random_crop])
+        # for testing and validation we pad all inputs to next larger multiple of 32
+        # to avoid issues with upsampling paths in encoder-decoder architectures
+        test_transforms = Compose([self.preprocess, pad_to])
+
+        train_dataset = Vaihingen2D(
+            split="train", transforms=train_transforms, **self.kwargs
+        )
 
         self.train_dataset: Dataset[Any]
         self.val_dataset: Dataset[Any]
 
         if self.val_split_pct > 0.0:
+            val_dataset = Vaihingen2D(
+                split="train", transforms=test_transforms, **self.kwargs
+            )
             self.train_dataset, self.val_dataset, _ = dataset_split(
-                dataset, val_pct=self.val_split_pct, test_pct=0.0
+                train_dataset, val_pct=self.val_split_pct, test_pct=0.0
             )
+            self.val_dataset.dataset = val_dataset
         else:
-            self.train_dataset = dataset
-            self.val_dataset = dataset
+            self.train_dataset = train_dataset
+            self.val_dataset = train_dataset
 
         self.test_dataset = Vaihingen2D(
-            split="test", transforms=transforms, **self.kwargs
+            split="test", transforms=test_transforms, **self.kwargs
         )
 
     def train_dataloader(self) -> DataLoader[Any]:
-        """Return a DataLoader for training.
+        """Return a DataLoader for training."""
+
+        def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
+            r_batch: Dict[str, Any] = default_collate(  # type: ignore[no-untyped-call]
+                batch
+            )
+            r_batch["image"] = torch.flatten(r_batch["image"], 0, 1)
+            r_batch["mask"] = torch.flatten(r_batch["mask"], 0, 1)
+            return r_batch
 
-        Returns:
-            training data loader
-        """
         return DataLoader(
             self.train_dataset,
-            batch_size=self.batch_size,
+            batch_size=self.train_batch_size,
             num_workers=self.num_workers,
+            collate_fn=collate_wrapper,
             shuffle=True,
         )
 
     def val_dataloader(self) -> DataLoader[Any]:
-        """Return a DataLoader for validation.
-
-        Returns:
-            validation data loader
-        """
+        """Return a DataLoader for validation."""
         return DataLoader(
-            self.val_dataset,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            shuffle=False,
+            self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
         )
 
     def test_dataloader(self) -> DataLoader[Any]:
-        """Return a DataLoader for testing.
-
-        Returns:
-            testing data loader
-        """
+        """Return a DataLoader for testing."""
         return DataLoader(
-            self.test_dataset,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            shuffle=False,
+            self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
         )
 
     def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:

From ebcd1ae6da6e195b9422d50bda1a41c74eb1cd83 Mon Sep 17 00:00:00 2001
From: Nils Lehmann
Date: Sun, 16 Oct 2022 19:00:44 +0200
Subject: [PATCH 02/17] requested changes to vaihingen

---
 torchgeo/datamodules/vaihingen.py | 64 ++++++++++++++++++++++++-------
 1 file changed, 50 insertions(+), 14 deletions(-)

diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py
index 8eca4cd9b93..bdc05a42361 100644
--- a/torchgeo/datamodules/vaihingen.py
+++ b/torchgeo/datamodules/vaihingen.py
@@ -3,17 +3,18 @@
 
 """Vaihingen datamodule."""
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import matplotlib.pyplot as plt
 import kornia.augmentation as K
 import pytorch_lightning as pl
 import torch
 from einops import repeat
-from torch.utils.data import DataLoader, Dataset
-from torch.utils.data._utils.collate import default_collate
+from torch.utils.data import DataLoader, Dataset, default_collate
 from torchvision.transforms import Compose
 
+from torchgeo.samplers.utils import _to_tuple
+
 from ..datasets import Vaihingen2D
 from .utils import dataset_split
 
@@ -30,7 +31,7 @@ def __init__(
         self,
         train_batch_size: int = 32,
         num_workers: int = 0,
-        patch_size: Tuple[int, int] = (64, 64),
+        patch_size: Union[Tuple[int, int], int] = (64, 64),
         num_patches_per_tile: int = 32,
         val_split_pct: float = 0.2,
         **kwargs: Any,
@@ -49,15 +50,16 @@ def __init__(
             **kwargs: Additional keyword arguments passed to
                 :class:`~torchgeo.datasets.Vaihingen2D`
 
-        :: versionchanged:: 0.4
+        .. versionchanged:: 0.4
             'batch_size' is renamed to 'train_batch_size', 'patch_size' and
             'num_patches_per_tile' introduced in order to randomly crop the
             variable size images during training
         """
         super().__init__()
+
         self.train_batch_size = train_batch_size
         self.num_workers = num_workers
-        self.patch_size = patch_size
+        self.patch_size = _to_tuple(patch_size)
         self.num_patches_per_tile = num_patches_per_tile
         self.val_split_pct = val_split_pct
         self.kwargs = kwargs
@@ -87,12 +89,20 @@ def setup(self, stage: Optional[str] = None) -> None:
         Args:
             stage: stage to set up
 
-        :: versionchanged:: 0.4
+        .. versionchanged:: 0.4
             Add functionality to randomly crop patches from a tile during
             training and pad validation and test samples to next multiple of 32
         """
 
         def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]:
+            """Construct 'num_patches_per_tile' random patches of input tile.
+
+            Args:
+                sample: contains image and mask tile from dataset
+
+            Returns:
+                stacked randomly cropped patches from input tile
+            """
             images, masks = [], []
             for i in range(self.num_patches_per_tile):
                 mask = repeat(sample["mask"], "h w -> t h w", t=2).float()
@@ -105,7 +115,14 @@ def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]:
             return sample
 
         def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]:
-            """Pad to next multiple of 32."""
+            """Pad image and mask to next multiple of 32.
+
+            Args:
+                sample: contains image and mask sample from dataset
+
+            Returns:
+                padded image and mask
+            """
             h, w = sample["image"].shape[1], sample["image"].shape[2]
             new_h = int(32 * ((h // 32) + 1))
             new_w = int(32 * ((w // 32) + 1))
@@ -145,12 +162,23 @@ def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]:
         )
 
     def train_dataloader(self) -> DataLoader[Any]:
-        """Return a DataLoader for training."""
+        """Return a DataLoader for training.
+
+        Returns:
+            training dataloader
+        """
 
         def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
-            r_batch: Dict[str, Any] = default_collate(  # type: ignore[no-untyped-call]
-                batch
-            )
+            """Define collate function to combine patches per tile and batch size.
+
+            Args:
+                batch: sample batch from dataloader containing image and mask
+
+            Returns:
+                sample batch where the batch dimension is
+                'train_batch_size' * 'num_patches_per_tile'
+            """
+            r_batch: Dict[str, Any] = default_collate(batch)
             r_batch["image"] = torch.flatten(r_batch["image"], 0, 1)
             r_batch["mask"] = torch.flatten(r_batch["mask"], 0, 1)
             return r_batch
@@ -164,13 +192,21 @@ def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
         )
 
     def val_dataloader(self) -> DataLoader[Any]:
-        """Return a DataLoader for validation."""
+        """Return a DataLoader for validation.
+
+        Returns:
+            validation dataloader
+        """
         return DataLoader(
             self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
         )
 
     def test_dataloader(self) -> DataLoader[Any]:
-        """Return a DataLoader for testing."""
+        """Return a DataLoader for testing.
+
+        Returns:
+            test dataloader
+        """
         return DataLoader(
             self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
         )

From 2c82ca40e4f86ece82fb6b0a670b4493f2c4de61 Mon Sep 17 00:00:00 2001
From: Nils Lehmann
Date: Sun, 16 Oct 2022 20:05:57 +0200
Subject: [PATCH 03/17] data loader

---
 torchgeo/datamodules/vaihingen.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py
index bdc05a42361..3357165ec9b 100644
--- a/torchgeo/datamodules/vaihingen.py
+++ b/torchgeo/datamodules/vaihingen.py
@@ -165,7 +165,7 @@ def train_dataloader(self) -> DataLoader[Any]:
         """Return a DataLoader for training.
 
         Returns:
-            training dataloader
+            training data loader
         """
 
         def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
@@ -195,7 +195,7 @@ def val_dataloader(self) -> DataLoader[Any]:
         """Return a DataLoader for validation.
 
         Returns:
-            validation dataloader
+            validation data loader
         """
         return DataLoader(
             self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
         )
@@ -205,7 +205,7 @@ def test_dataloader(self) -> DataLoader[Any]:
         """Return a DataLoader for testing.
 
         Returns:
-            test dataloader
+            test data loader
         """
         return DataLoader(
             self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
         )

From baf8c05f506b52cee141d7f6273d674402344f2f Mon Sep 17 00:00:00 2001
From: Nils Lehmann
Date: Mon, 17 Oct 2022 11:48:56 +0200
Subject: [PATCH 04/17] fix error an clarity

---
 torchgeo/datamodules/vaihingen.py | 21 ++++++++++-----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py
index 3357165ec9b..50e066a638a 100644
--- a/torchgeo/datamodules/vaihingen.py
+++ b/torchgeo/datamodules/vaihingen.py
@@ -9,7 +9,7 @@
 import kornia.augmentation as K
 import pytorch_lightning as pl
 import torch
-from einops import repeat
+from einops import rearrange
 from torch.utils.data import DataLoader, Dataset, default_collate
 from torchvision.transforms import Compose
 
@@ -65,7 +65,7 @@ def __init__(
         self.kwargs = kwargs
 
         self.rcrop = K.AugmentationSequential(
-            K.RandomCrop(patch_size), data_keys=["input", "mask"], same_on_batch=True
+            K.RandomCrop(patch_size), data_keys=["input", "mask"]
         )
 
     def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
@@ -105,13 +105,12 @@ def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]:
             """
             images, masks = [], []
             for i in range(self.num_patches_per_tile):
-                mask = repeat(sample["mask"], "h w -> t h w", t=2).float()
-                image, mask = self.rcrop(sample["image"], mask)
-                mask = mask.squeeze()[0]
-                images.append(image.squeeze())
-                masks.append(mask.long())
-            sample["image"] = torch.stack(images).squeeze(0)
-            sample["mask"] = torch.stack(masks).squeeze(0)
+                image, mask = self.rcrop(sample["image"], sample["mask"].float())
+                images.append(image.squeeze(0))
+                masks.append(mask.squeeze().long())
+
+            sample["image"] = torch.stack(images)
+            sample["mask"] = torch.stack(masks)
             return sample
 
         def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]:
@@ -179,8 +178,8 @@ def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
                 'train_batch_size' * 'num_patches_per_tile'
             """
             r_batch: Dict[str, Any] = default_collate(batch)
-            r_batch["image"] = torch.flatten(r_batch["image"], 0, 1)
-            r_batch["mask"] = torch.flatten(r_batch["mask"], 0, 1)
+            r_batch["image"] = rearrange(r_batch["image"], "b t c h w -> (b t) c h w")
+            r_batch["mask"] = rearrange(r_batch["mask"], "b t h w -> (b t) h w")
             return r_batch
 
         return DataLoader(

From 217b504a528fd71ec99743f5e141daf329b1bb88 Mon Sep 17 00:00:00 2001
From: Nils Lehmann
Date: Fri, 28 Oct 2022 19:42:01 +0200
Subject: [PATCH 05/17] fix failing test

---
 torchgeo/datamodules/vaihingen.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py
index 50e066a638a..77506520aee 100644
--- a/torchgeo/datamodules/vaihingen.py
+++ b/torchgeo/datamodules/vaihingen.py
@@ -10,7 +10,8 @@
 import pytorch_lightning as pl
 import torch
 from einops import rearrange
-from torch.utils.data import DataLoader, Dataset, default_collate
+from torch.utils.data import DataLoader, Dataset
+from torch.utils.data._utils.collate import default_collate
 from torchvision.transforms import Compose
 
 from torchgeo.samplers.utils import _to_tuple
@@ -177,7 +178,9 @@ def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
                 sample batch where the batch dimension is
                 'train_batch_size' * 'num_patches_per_tile'
             """
-            r_batch: Dict[str, Any] = default_collate(batch)
+            r_batch: Dict[str, Any] = default_collate(  # type: ignore[no-untyped-call]
+                batch
+            )
             r_batch["image"] = rearrange(r_batch["image"], "b t c h w -> (b t) c h w")
             r_batch["mask"] = rearrange(r_batch["mask"], "b t h w -> (b t) h w")
             return r_batch

From 7e32a160e2110ad7f8c4cf426c75c9ca7e64c52b Mon Sep 17 00:00:00 2001
From: Nils Lehmann
Date: Sat, 29 Oct 2022 08:33:14 +0200
Subject: [PATCH 06/17] fix failing test crop augmentation

---
 torchgeo/datamodules/vaihingen.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py
index 77506520aee..3d0b8723a21 100644
--- a/torchgeo/datamodules/vaihingen.py
+++ b/torchgeo/datamodules/vaihingen.py
@@ -66,7 +66,7 @@ def __init__(
         self.kwargs = kwargs
 
         self.rcrop = K.AugmentationSequential(
-            K.RandomCrop(patch_size), data_keys=["input", "mask"]
+            K.RandomCrop(patch_size), data_keys=["input", "mask"], same_on_batch=True
         )
 
     def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:

From eecc12c12f43cf43ac53707e8e9ac907a206659d Mon Sep 17 00:00:00 2001
From: Nils Lehmann
Date: Sat, 29 Oct 2022 14:09:34 +0200
Subject: [PATCH 07/17] found a bug

---
 torchgeo/datamodules/vaihingen.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py
index 3d0b8723a21..e762a9e7be8 100644
--- a/torchgeo/datamodules/vaihingen.py
+++ b/torchgeo/datamodules/vaihingen.py
@@ -66,7 +66,9 @@ def __init__(
         self.kwargs = kwargs
 
         self.rcrop = K.AugmentationSequential(
-            K.RandomCrop(patch_size), data_keys=["input", "mask"], same_on_batch=True
+            K.RandomCrop(self.patch_size),
+            data_keys=["input", "mask"],
+            same_on_batch=True,
         )
 
     def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:

From ad32803faa255c8513e0142a5bc9895df0028336 Mon Sep 17 00:00:00 2001
From: Nils Lehmann
Date: Sat, 29 Oct 2022 14:11:22 +0200
Subject: [PATCH 08/17] remove same_batch param

---
 torchgeo/datamodules/vaihingen.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py
index e762a9e7be8..9bb589b1477 100644
--- a/torchgeo/datamodules/vaihingen.py
+++ b/torchgeo/datamodules/vaihingen.py
@@ -66,9 +66,7 @@ def __init__(
         self.kwargs = kwargs
 
         self.rcrop = K.AugmentationSequential(
-            K.RandomCrop(self.patch_size),
-            data_keys=["input", "mask"],
-            same_on_batch=True,
+            K.RandomCrop(self.patch_size), data_keys=["input", "mask"]
         )
 
     def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:

From 9c790d0433bf556ca7d7a7ccc681ebcfa2418152 Mon Sep 17 00:00:00 2001
From: Caleb Robinson
Date: Thu, 17 Nov 2022 04:43:57 -0800
Subject: [PATCH 09/17] Trying to get minimum tests to pass

---
 torchgeo/datamodules/vaihingen.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py
index 9bb589b1477..7ce8e61ba6b 100644
--- a/torchgeo/datamodules/vaihingen.py
+++ b/torchgeo/datamodules/vaihingen.py
@@ -66,7 +66,7 @@ def __init__(
         self.kwargs = kwargs
 
         self.rcrop = K.AugmentationSequential(
-            K.RandomCrop(self.patch_size), data_keys=["input", "mask"]
+            K.RandomCrop(self.patch_size, align_corners=False), data_keys=["input", "mask"]
         )
 
     def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:

From a2dfed30af9bff065b3fdf9d081bbed6066fc8c0 Mon Sep 17 00:00:00 2001
From: Caleb Robinson
Date: Thu, 17 Nov 2022 04:46:26 -0800
Subject: [PATCH 10/17] Formatting

---
 torchgeo/datamodules/vaihingen.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py
index 7ce8e61ba6b..faefa8a8297 100644
--- a/torchgeo/datamodules/vaihingen.py
+++ b/torchgeo/datamodules/vaihingen.py
@@ -66,7 +66,8 @@ def __init__(
         self.kwargs = kwargs
 
         self.rcrop = K.AugmentationSequential(
-            K.RandomCrop(self.patch_size, align_corners=False), data_keys=["input", "mask"]
+            K.RandomCrop(self.patch_size, align_corners=False),
+            data_keys=["input", "mask"]
         )
 
     def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:

From 7af3f98d74d3a8eacc0bb902ba4a5cb5b2d3e313 Mon Sep 17 00:00:00 2001
From: Caleb Robinson
Date: Thu, 17 Nov 2022 04:47:24 -0800
Subject: [PATCH 11/17] Formatting again

---
 torchgeo/datamodules/vaihingen.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py
index faefa8a8297..e3b0c6dbe7d 100644
--- a/torchgeo/datamodules/vaihingen.py
+++ b/torchgeo/datamodules/vaihingen.py
@@ -67,7 +67,7 @@ def __init__(
 
         self.rcrop = K.AugmentationSequential(
             K.RandomCrop(self.patch_size, align_corners=False),
-            data_keys=["input", "mask"]
+            data_keys=["input", "mask"],
         )
 
     def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:

From 0258a96d51506df700e447be651855d389ebc925 Mon Sep 17 00:00:00 2001
From: "Adam J. Stewart"
Date: Sat, 17 Dec 2022 12:23:28 -0600
Subject: [PATCH 12/17] Update torchgeo/datamodules/vaihingen.py

---
 torchgeo/datamodules/vaihingen.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py
index e3b0c6dbe7d..63d294fbcee 100644
--- a/torchgeo/datamodules/vaihingen.py
+++ b/torchgeo/datamodules/vaihingen.py
@@ -14,7 +14,7 @@
 from torch.utils.data._utils.collate import default_collate
 from torchvision.transforms import Compose
 
-from torchgeo.samplers.utils import _to_tuple
+from ..samplers.utils import _to_tuple
 
 from ..datasets import Vaihingen2D
 from .utils import dataset_split

From ab12bf7f8fbe8162dafb3a772118e91de172756c Mon Sep 17 00:00:00 2001
From: "Adam J. Stewart"
Date: Sat, 17 Dec 2022 12:25:50 -0600
Subject: [PATCH 13/17] Sort imports

---
 torchgeo/datamodules/vaihingen.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py
index 63d294fbcee..c596075ef2d 100644
--- a/torchgeo/datamodules/vaihingen.py
+++ b/torchgeo/datamodules/vaihingen.py
@@ -14,9 +14,8 @@
 from torch.utils.data._utils.collate import default_collate
 from torchvision.transforms import Compose
 
-from ..samplers.utils import _to_tuple
-
 from ..datasets import Vaihingen2D
+from ..samplers.utils import _to_tuple
 from .utils import dataset_split

From 9d98914fba94f4297b0ef10b1f98b2584caf41c3 Mon Sep 17 00:00:00 2001
From: Caleb Robinson
Date: Tue, 20 Dec 2022 17:14:32 +0000
Subject: [PATCH 14/17] Isort, yousort, we all sort

---
 torchgeo/datamodules/vaihingen.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py
index c596075ef2d..0e3e9a42c26 100644
--- a/torchgeo/datamodules/vaihingen.py
+++ b/torchgeo/datamodules/vaihingen.py
@@ -5,8 +5,8 @@
 
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-import matplotlib.pyplot as plt
 import kornia.augmentation as K
+import matplotlib.pyplot as plt
 import pytorch_lightning as pl
 import torch
 from einops import rearrange

From 88eb51f201e417c1a2a988fbdc7aad836eea9de4 Mon Sep 17 00:00:00 2001
From: "Adam J. Stewart"
Date: Thu, 29 Dec 2022 22:07:18 -0600
Subject: [PATCH 15/17] Same logic as deepglobe

---
 conf/deepglobelandcover.yaml               |  21 +++
 conf/vaihingen2d.yaml                      |  21 +++
 tests/conf/vaihingen2d.yaml                |  21 +++
 tests/datamodules/test_vaihingen.py        |  44 -----
 tests/trainers/test_segmentation.py        |   2 +
 torchgeo/datamodules/deepglobelandcover.py |   5 +-
 torchgeo/datamodules/vaihingen.py          | 209 ++++++++-------------
 torchgeo/datasets/vaihingen.py             |   2 +-
 train.py                                   |   4 +
 9 files changed, 147 insertions(+), 182 deletions(-)
 create mode 100644 conf/deepglobelandcover.yaml
 create mode 100644 conf/vaihingen2d.yaml
 create mode 100644 tests/conf/vaihingen2d.yaml
 delete mode 100644 tests/datamodules/test_vaihingen.py

diff --git a/conf/deepglobelandcover.yaml b/conf/deepglobelandcover.yaml
new file mode 100644
index 00000000000..7732406da45
--- /dev/null
+++ b/conf/deepglobelandcover.yaml
@@ -0,0 +1,21 @@
+experiment:
+  task: "deepglobelandcover"
+  module:
+    loss: "ce"
+    model: "unet"
+    backbone: "resnet18"
+    weights: null
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 6
+    verbose: false
+    in_channels: 3
+    num_classes: 7
+    num_filters: 1
+    ignore_index: null
+  datamodule:
+    root: "data/deepglobelandcover"
+    num_tiles_per_batch: 16
+    num_patches_per_tile: 16
+    patch_size: 64
+    val_split_pct: 0.5
+    num_workers: 0
diff --git a/conf/vaihingen2d.yaml b/conf/vaihingen2d.yaml
new file mode 100644
index 00000000000..0f3015faf66
--- /dev/null
+++ b/conf/vaihingen2d.yaml
@@ -0,0 +1,21 @@
+experiment:
+  task: "vaihingen2d"
+  module:
+    loss: "ce"
+    model: "unet"
+    backbone: "resnet18"
+    weights: null
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 6
+    verbose: false
+    in_channels: 3
+    num_classes: 7
+    num_filters: 1
+    ignore_index: null
+  datamodule:
+    root: "data/vaihingen"
+    num_tiles_per_batch: 16
+    num_patches_per_tile: 16
+    patch_size: 64
+    val_split_pct: 0.5
+    num_workers: 0
diff --git a/tests/conf/vaihingen2d.yaml b/tests/conf/vaihingen2d.yaml
new file mode 100644
index 00000000000..0184b929a0a
--- /dev/null
+++ b/tests/conf/vaihingen2d.yaml
@@ -0,0 +1,21 @@
+experiment:
+  task: "vaihingen2d"
+  module:
+    loss: "ce"
+    model: "unet"
+    backbone: "resnet18"
+    weights: null
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 6
+    verbose: false
+    in_channels: 3
+    num_classes: 7
+    num_filters: 1
+    ignore_index: null
+  datamodule:
+    root: "tests/data/vaihingen"
+    num_tiles_per_batch: 1
+    num_patches_per_tile: 1
+    patch_size: 2
+    val_split_pct: 0.5
+    num_workers: 0
diff --git a/tests/datamodules/test_vaihingen.py b/tests/datamodules/test_vaihingen.py
deleted file mode 100644
index 896faf0000d..00000000000
--- a/tests/datamodules/test_vaihingen.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-
-import os
-
-import matplotlib.pyplot as plt
-import pytest
-from _pytest.fixtures import SubRequest
-
-from torchgeo.datamodules import Vaihingen2DDataModule
-from torchgeo.datasets import unbind_samples
-
-
-class TestVaihingen2DDataModule:
-    @pytest.fixture(scope="class", params=[0.0, 0.5])
-    def datamodule(self, request: SubRequest) -> Vaihingen2DDataModule:
-        root = os.path.join("tests", "data", "vaihingen")
-        batch_size = 1
-        num_workers = 0
-        val_split_size = request.param
-        dm = Vaihingen2DDataModule(
-            root=root,
-            train_batch_size=batch_size,
-            num_workers=num_workers,
-            val_split_pct=val_split_size,
-        )
-        dm.prepare_data()
-        dm.setup()
-        return dm
-
-    def test_train_dataloader(self, datamodule: Vaihingen2DDataModule) -> None:
-        next(iter(datamodule.train_dataloader()))
-
-    def test_val_dataloader(self, datamodule: Vaihingen2DDataModule) -> None:
-        next(iter(datamodule.val_dataloader()))
-
-    def test_test_dataloader(self, datamodule: Vaihingen2DDataModule) -> None:
-        next(iter(datamodule.test_dataloader()))
-
-    def test_plot(self, datamodule: Vaihingen2DDataModule) -> None:
-        batch = next(iter(datamodule.train_dataloader()))
-        sample = unbind_samples(batch)[0]
-        datamodule.plot(sample)
-        plt.close()
diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py
index 6c9b1f2ff69..1ffb3931516 100644
--- a/tests/trainers/test_segmentation.py
+++ b/tests/trainers/test_segmentation.py
@@ -21,6 +21,7 @@
     NAIPChesapeakeDataModule,
     SEN12MSDataModule,
     SpaceNet1DataModule,
+    Vaihingen2DDataModule,
 )
 from torchgeo.datasets import LandCoverAI
 from torchgeo.trainers import SemanticSegmentationTask
@@ -50,6 +51,7 @@ class TestSemanticSegmentationTask:
             ("sen12ms_s2_all", SEN12MSDataModule),
             ("sen12ms_s2_reduced", SEN12MSDataModule),
             ("spacenet1", SpaceNet1DataModule),
+            ("vaihingen2d", Vaihingen2DDataModule),
         ],
     )
     def test_trainer(
diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py
index ac71f69bd77..1dbdacb6776 100644
--- a/torchgeo/datamodules/deepglobelandcover.py
+++ b/torchgeo/datamodules/deepglobelandcover.py
@@ -7,6 +7,7 @@
 
 import matplotlib.pyplot as plt
 import pytorch_lightning as pl
+from einops import rearrange
 from kornia.augmentation import Normalize
 from torch.utils.data import DataLoader
 
@@ -136,7 +137,7 @@ def on_after_batch_transfer(
             A batch of data
         """
         # Kornia requires masks to have a channel dimension
-        batch["mask"] = batch["mask"].unsqueeze(1)
+        batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w")
 
         if self.trainer:
             if self.trainer.training:
@@ -145,7 +146,7 @@ def on_after_batch_transfer(
                 batch = self.test_transform(batch)
 
         # Torchmetrics does not support masks with a channel dimension
-        batch["mask"] = batch["mask"].squeeze(1)
+        batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w")
 
         return batch
 
diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py
index 0e3e9a42c26..88ddbe73cbb 100644
--- a/torchgeo/datamodules/vaihingen.py
+++ b/torchgeo/datamodules/vaihingen.py
@@ -3,19 +3,18 @@
 
 """Vaihingen datamodule."""
 
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
-import kornia.augmentation as K
 import matplotlib.pyplot as plt
 import pytorch_lightning as pl
-import torch
 from einops import rearrange
-from torch.utils.data import DataLoader, Dataset
-from torch.utils.data._utils.collate import default_collate
-from torchvision.transforms import Compose
+from kornia.augmentation import Normalize
+from torch.utils.data import DataLoader
 
 from ..datasets import Vaihingen2D
 from ..samplers.utils import _to_tuple
+from ..transforms import AugmentationSequential
+from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop
 from .utils import dataset_split
 
 
@@ -29,171 +28,85 @@ class Vaihingen2DDataModule(pl.LightningDataModule):
 
     def __init__(
         self,
-        train_batch_size: int = 32,
-        num_workers: int = 0,
-        patch_size: Union[Tuple[int, int], int] = (64, 64),
-        num_patches_per_tile: int = 32,
+        num_tiles_per_batch: int = 16,
+        num_patches_per_tile: int = 16,
+        patch_size: Union[Tuple[int, int], int] = 64,
         val_split_pct: float = 0.2,
+        num_workers: int = 0,
         **kwargs: Any,
     ) -> None:
-        """Initialize a LightningDataModule for Vaihingen2D based DataLoaders.
+        """Initialize a new LightningDataModule instance.
+
+        The Vaihingen2D dataset contains images that are too large to pass
+        directly through a model. Instead, we randomly sample patches from image tiles
+        during training and chop up image tiles into patch grids during evaluation.
+        During training, the effective batch size is equal to
+        ``num_tiles_per_batch`` x ``num_patches_per_tile``.
 
         Args:
-            train_batch_size: The batch size used in the train DataLoader
-                (val_batch_size == test_batch_size == 1). The effective batch size
-                will be 'train_batch_size' * 'num_patches_per_tile'
-            num_workers: The number of workers to use in all created DataLoaders
-            val_split_pct: What percentage of the dataset to use as a validation set
-            patch_size: Size of random patch from image and mask (height, width), should
-                be a multiple of 32 for most segmentation architectures
-            num_patches_per_tile: number of random patches per sample
+            num_tiles_per_batch: The number of image tiles to sample from during
+                training
+            num_patches_per_tile: The number of patches to randomly sample from each
+                image tile during training
+            patch_size: The size of each patch, either ``size`` or ``(height, width)``.
+                Should be a multiple of 32 for most segmentation architectures
+            val_split_pct: The percentage of the dataset to use as a validation set
+            num_workers: The number of workers to use for parallel data loading
             **kwargs: Additional keyword arguments passed to
                 :class:`~torchgeo.datasets.Vaihingen2D`
 
         .. versionchanged:: 0.4
-            'batch_size' is renamed to 'train_batch_size', 'patch_size' and
-            'num_patches_per_tile' introduced in order to randomly crop the
-            variable size images during training
+            *batch_size* was replaced by *num_tile_per_batch*, *num_patches_per_tile*,
+            and *patch_size*.
         """
         super().__init__()
 
-        self.train_batch_size = train_batch_size
-        self.num_workers = num_workers
-        self.patch_size = _to_tuple(patch_size)
+        self.num_tiles_per_batch = num_tiles_per_batch
         self.num_patches_per_tile = num_patches_per_tile
+        self.patch_size = _to_tuple(patch_size)
         self.val_split_pct = val_split_pct
+        self.num_workers = num_workers
         self.kwargs = kwargs
 
-        self.rcrop = K.AugmentationSequential(
-            K.RandomCrop(self.patch_size, align_corners=False),
-            data_keys=["input", "mask"],
+        self.train_transform = AugmentationSequential(
+            Normalize(mean=0.0, std=255.0),
+            _RandomNCrop(self.patch_size, self.num_patches_per_tile),
+            data_keys=["image", "mask"],
+        )
+        self.test_transform = AugmentationSequential(
+            Normalize(mean=0.0, std=255.0),
+            _ExtractTensorPatches(self.patch_size),
+            data_keys=["image", "mask"],
         )
-
-    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
-        """Transform a single sample from the Dataset.
-
-        Args:
-            sample: input image dictionary
-
-        Returns:
-            preprocessed sample
-        """
-        sample["image"] = sample["image"].float()
-        sample["image"] /= 255.0
-        return sample
 
     def setup(self, stage: Optional[str] = None) -> None:
-        """Initialize the main ``Dataset`` objects.
+        """Initialize the main Dataset objects.
 
         This method is called once per GPU per run.
 
         Args:
             stage: stage to set up
-
-        .. versionchanged:: 0.4
-            Add functionality to randomly crop patches from a tile during
-            training and pad validation and test samples to next multiple of 32
         """
-
-        def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]:
-            """Construct 'num_patches_per_tile' random patches of input tile.
-
-            Args:
-                sample: contains image and mask tile from dataset
-
-            Returns:
-                stacked randomly cropped patches from input tile
-            """
-            images, masks = [], []
-            for i in range(self.num_patches_per_tile):
-                image, mask = self.rcrop(sample["image"], sample["mask"].float())
-                images.append(image.squeeze(0))
-                masks.append(mask.squeeze().long())
-
-            sample["image"] = torch.stack(images)
-            sample["mask"] = torch.stack(masks)
-            return sample
-
-        def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]:
-            """Pad image and mask to next multiple of 32.
-
-            Args:
-                sample: contains image and mask sample from dataset
-
-            Returns:
-                padded image and mask
-            """
-            h, w = sample["image"].shape[1], sample["image"].shape[2]
-            new_h = int(32 * ((h // 32) + 1))
-            new_w = int(32 * ((w // 32) + 1))
-
-            padto = K.PadTo((new_h, new_w))
-
-            sample["image"] = padto(sample["image"])[0]
-            sample["mask"] = padto(sample["mask"].float()).long()[0, 0]
-            return sample
-
-        train_transforms = Compose([self.preprocess, n_random_crop])
-        # for testing and validation we pad all inputs to next larger multiple of 32
-        # to avoid issues with upsampling paths in encoder-decoder architectures
-        test_transforms = Compose([self.preprocess, pad_to])
-
-        train_dataset = Vaihingen2D(
-            split="train", transforms=train_transforms, **self.kwargs
+        train_dataset = Vaihingen2D(split="train", **self.kwargs)
+        self.train_dataset, self.val_dataset = dataset_split(
+            train_dataset, self.val_split_pct
         )
+        self.test_dataset = Vaihingen2D(split="test", **self.kwargs)
 
-        self.train_dataset: Dataset[Any]
-        self.val_dataset: Dataset[Any]
-
-        if self.val_split_pct > 0.0:
-            val_dataset = Vaihingen2D(
-                split="train", transforms=test_transforms, **self.kwargs
-            )
-            self.train_dataset, self.val_dataset, _ = dataset_split(
-                train_dataset, val_pct=self.val_split_pct, test_pct=0.0
-            )
-            self.val_dataset.dataset = val_dataset
-        else:
-            self.train_dataset = train_dataset
-            self.val_dataset = train_dataset
-
-        self.test_dataset = Vaihingen2D(
-            split="test", transforms=test_transforms, **self.kwargs
-        )
-
-    def train_dataloader(self) -> DataLoader[Any]:
+    def train_dataloader(self) -> DataLoader[Dict[str, Any]]:
         """Return a DataLoader for training.
 
         Returns:
             training data loader
         """
-
-        def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
-            """Define collate function to combine patches per tile and batch size.
-
-            Args:
-                batch: sample batch from dataloader containing image and mask
-
-            Returns:
-                sample batch where the batch dimension is
-                'train_batch_size' * 'num_patches_per_tile'
-            """
-            r_batch: Dict[str, Any] = default_collate(  # type: ignore[no-untyped-call]
-                batch
-            )
-            r_batch["image"] = rearrange(r_batch["image"], "b t c h w -> (b t) c h w")
-            r_batch["mask"] = rearrange(r_batch["mask"], "b t h w -> (b t) h w")
-            return r_batch
-
         return DataLoader(
             self.train_dataset,
-            batch_size=self.train_batch_size,
+            batch_size=self.num_tiles_per_batch,
             num_workers=self.num_workers,
-            collate_fn=collate_wrapper,
             shuffle=True,
         )
 
-    def val_dataloader(self) -> DataLoader[Any]:
+    def val_dataloader(self) -> DataLoader[Dict[str, Any]]:
         """Return a DataLoader for validation.
 
         Returns:
@@ -203,16 +116,42 @@ def val_dataloader(self) -> DataLoader[Any]:
             self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
         )
 
-    def test_dataloader(self) -> DataLoader[Any]:
+    def test_dataloader(self) -> DataLoader[Dict[str, Any]]:
         """Return a DataLoader for testing.
 
         Returns:
-            test data loader
+            testing data loader
         """
         return DataLoader(
             self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
         )
 
+    def on_after_batch_transfer(
+        self, batch: Dict[str, Any], dataloader_idx: int
+    ) -> Dict[str, Any]:
+        """Apply augmentations to batch after transferring to GPU.
+
+        Args:
+            batch: A batch of data that needs to be altered or augmented
+            dataloader_idx: The index of the dataloader to which the batch belongs
+
+        Returns:
+            A batch of data
+        """
+        # Kornia requires masks to have a channel dimension
+        batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w")
+
+        if self.trainer:
+            if self.trainer.training:
+                batch = self.train_transform(batch)
+            elif self.trainer.validating or self.trainer.testing:
+                batch = self.test_transform(batch)
+
+        # Torchmetrics does not support masks with a channel dimension
+        batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w")
+
+        return batch
+
     def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
         """Run :meth:`torchgeo.datasets.Vaihingen2D.plot`.
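
The mask round-trip in on_after_batch_transfer above is the one piece that is easy to get wrong, so here it is in isolation — a minimal sketch with assumed batch sizes (einops' "()" denotes a length-1 axis):

import torch
from einops import rearrange

batch = {
    "image": torch.rand(2, 3, 64, 64),         # collated image tiles
    "mask": torch.randint(0, 7, (2, 64, 64)),  # (B, H, W) class indices
}

# Kornia requires masks to have a channel dimension
batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w")  # (2, 1, 64, 64)

# ... self.train_transform / self.test_transform would run here ...

# Torchmetrics does not support masks with a channel dimension
batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w")  # back to (2, 64, 64)
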
diff --git a/torchgeo/datasets/vaihingen.py b/torchgeo/datasets/vaihingen.py
index 4945247f574..8509607f2f0 100644
--- a/torchgeo/datasets/vaihingen.py
+++ b/torchgeo/datasets/vaihingen.py
@@ -188,7 +188,7 @@ def _load_image(self, index: int) -> Tensor:
             array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB"))
             tensor = torch.from_numpy(array)
             # Convert from HxWxC to CxHxW
-            tensor = tensor.permute((2, 0, 1))
+            tensor = tensor.permute((2, 0, 1)).float()
         return tensor
 
     def _load_target(self, index: int) -> Tensor:
diff --git a/train.py b/train.py
index cb001437686..5bad176717d 100755
--- a/train.py
+++ b/train.py
@@ -17,6 +17,7 @@
     BigEarthNetDataModule,
     ChesapeakeCVPRDataModule,
     COWCCountingDataModule,
+    DeepGlobeLandCoverDataModule,
     ETCI2021DataModule,
     EuroSATDataModule,
     InriaAerialImageLabelingDataModule,
@@ -30,6 +31,7 @@
     SpaceNet1DataModule,
     TropicalCycloneDataModule,
     UCMercedDataModule,
+    Vaihingen2DDataModule,
 )
 from torchgeo.trainers import (
     BYOLTask,
@@ -48,6 +50,7 @@
     "chesapeake_cvpr": (SemanticSegmentationTask, ChesapeakeCVPRDataModule),
     "cowc_counting": (RegressionTask, COWCCountingDataModule),
     "cyclone": (RegressionTask, TropicalCycloneDataModule),
+    "deepglobelandcover": (SemanticSegmentationTask, DeepGlobeLandCoverDataModule),
     "eurosat": (ClassificationTask, EuroSATDataModule),
     "etci2021": (SemanticSegmentationTask, ETCI2021DataModule),
     "inria": (SemanticSegmentationTask, InriaAerialImageLabelingDataModule),
@@ -60,6 +63,7 @@
     "so2sat": (ClassificationTask, So2SatDataModule),
     "spacenet1": (SemanticSegmentationTask, SpaceNet1DataModule),
     "ucmerced": (ClassificationTask, UCMercedDataModule),
+    "vaihingen2d": (SemanticSegmentationTask, Vaihingen2DDataModule),
 }

From e291631235e3325ee2597a723323510c76e96c74 Mon Sep 17 00:00:00 2001
From: "Adam J. Stewart"
Date: Thu, 29 Dec 2022 22:09:56 -0600
Subject: [PATCH 16/17] More-specific types

---
 torchgeo/datamodules/deepglobelandcover.py | 10 +++++-----
 torchgeo/datamodules/vaihingen.py          | 10 +++++-----
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py
index 1dbdacb6776..efc85088dcb 100644
--- a/torchgeo/datamodules/deepglobelandcover.py
+++ b/torchgeo/datamodules/deepglobelandcover.py
@@ -91,7 +91,7 @@ def setup(self, stage: Optional[str] = None) -> None:
         )
         self.test_dataset = DeepGlobeLandCover(split="test", **self.kwargs)
 
-    def train_dataloader(self) -> DataLoader[Dict[str, Any]]:
+    def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Return a DataLoader for training.
 
         Returns:
@@ -104,7 +104,7 @@ def train_dataloader(self) -> DataLoader[Dict[str, Any]]:
             shuffle=True,
         )
 
-    def val_dataloader(self) -> DataLoader[Dict[str, Any]]:
+    def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Return a DataLoader for validation.
 
         Returns:
@@ -114,7 +114,7 @@ def val_dataloader(self) -> DataLoader[Dict[str, Any]]:
             self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
         )
 
-    def test_dataloader(self) -> DataLoader[Dict[str, Any]]:
+    def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Return a DataLoader for testing.
 
         Returns:
@@ -125,8 +125,8 @@ def test_dataloader(self) -> DataLoader[Dict[str, Any]]:
         )
 
     def on_after_batch_transfer(
-        self, batch: Dict[str, Any], dataloader_idx: int
-    ) -> Dict[str, Any]:
+        self, batch: Dict[str, Tensor], dataloader_idx: int
+    ) -> Dict[str, Tensor]:
         """Apply augmentations to batch after transferring to GPU.
 
         Args:
diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py
index 88ddbe73cbb..906bca08254 100644
--- a/torchgeo/datamodules/vaihingen.py
+++ b/torchgeo/datamodules/vaihingen.py
@@ -93,7 +93,7 @@ def setup(self, stage: Optional[str] = None) -> None:
         )
         self.test_dataset = Vaihingen2D(split="test", **self.kwargs)
 
-    def train_dataloader(self) -> DataLoader[Dict[str, Any]]:
+    def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Return a DataLoader for training.
 
         Returns:
@@ -106,7 +106,7 @@ def train_dataloader(self) -> DataLoader[Dict[str, Any]]:
             shuffle=True,
         )
 
-    def val_dataloader(self) -> DataLoader[Dict[str, Any]]:
+    def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Return a DataLoader for validation.
 
         Returns:
@@ -116,7 +116,7 @@ def val_dataloader(self) -> DataLoader[Dict[str, Any]]:
             self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
         )
 
-    def test_dataloader(self) -> DataLoader[Dict[str, Any]]:
+    def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Return a DataLoader for testing.
 
         Returns:
@@ -127,8 +127,8 @@ def test_dataloader(self) -> DataLoader[Dict[str, Any]]:
         )
 
     def on_after_batch_transfer(
-        self, batch: Dict[str, Any], dataloader_idx: int
-    ) -> Dict[str, Any]:
+        self, batch: Dict[str, Tensor], dataloader_idx: int
+    ) -> Dict[str, Tensor]:
         """Apply augmentations to batch after transferring to GPU.
 
         Args:

From 6ea95d6b45adc87f77dec2209270be33c6779b2f Mon Sep 17 00:00:00 2001
From: "Adam J. Stewart"
Date: Thu, 29 Dec 2022 22:11:21 -0600
Subject: [PATCH 17/17] Missing import

---
 torchgeo/datamodules/deepglobelandcover.py | 1 +
 torchgeo/datamodules/vaihingen.py          | 1 +
 2 files changed, 2 insertions(+)

diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py
index efc85088dcb..bd25f0a7c1c 100644
--- a/torchgeo/datamodules/deepglobelandcover.py
+++ b/torchgeo/datamodules/deepglobelandcover.py
@@ -9,6 +9,7 @@
 import pytorch_lightning as pl
 from einops import rearrange
 from kornia.augmentation import Normalize
+from torch import Tensor
 from torch.utils.data import DataLoader
 
 from ..datasets import DeepGlobeLandCover
diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py
index 906bca08254..aa1499a6742 100644
--- a/torchgeo/datamodules/vaihingen.py
+++ b/torchgeo/datamodules/vaihingen.py
@@ -9,6 +9,7 @@
 import pytorch_lightning as pl
 from einops import rearrange
 from kornia.augmentation import Normalize
+from torch import Tensor
 from torch.utils.data import DataLoader
 
 from ..datasets import Vaihingen2D
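
For reference, a usage sketch of the datamodule as it stands at the end of the series, mirroring the values in conf/vaihingen2d.yaml; the root path is a placeholder and the dataset is assumed to already be on disk:

from torchgeo.datamodules import Vaihingen2DDataModule

dm = Vaihingen2DDataModule(
    root="data/vaihingen",  # placeholder path
    num_tiles_per_batch=16,
    num_patches_per_tile=16,
    patch_size=64,          # a multiple of 32, as the docstring asks
    val_split_pct=0.2,
    num_workers=0,
)
dm.setup()

# The train loader batches whole tiles; on_after_batch_transfer then applies
# Normalize + _RandomNCrop on-device, so the model sees an effective batch of
# num_tiles_per_batch * num_patches_per_tile = 16 * 16 = 256 patches,
# i.e. images of shape (256, 3, 64, 64) and masks of shape (256, 64, 64).
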