From f2d311561ddf25edcafa21fc3ead066b895b92b1 Mon Sep 17 00:00:00 2001
From: "Adam J. Stewart" <ajstewart426@gmail.com>
Date: Mon, 26 Dec 2022 10:20:24 -0600
Subject: [PATCH] InriaAerialImageLabelingDataModule: fix predict dimensions
 (#975)

* InriaAerialImageLabelingDataModule: fix predict dimensions

* Record number of patches for reconstruction
---
 tests/conf/{inria.yaml => inria_test.yaml} |  0
 tests/conf/inria_train.yaml                | 20 +++++++
 tests/conf/inria_val.yaml                  | 20 +++++++
 tests/datamodules/test_inria.py            | 67 ----------------------
 tests/trainers/test_segmentation.py        |  8 ++-
 torchgeo/datamodules/inria.py              |  4 +-
 6 files changed, 49 insertions(+), 70 deletions(-)
 rename tests/conf/{inria.yaml => inria_test.yaml} (100%)
 create mode 100644 tests/conf/inria_train.yaml
 create mode 100644 tests/conf/inria_val.yaml
 delete mode 100644 tests/datamodules/test_inria.py

diff --git a/tests/conf/inria.yaml b/tests/conf/inria_test.yaml
similarity index 100%
rename from tests/conf/inria.yaml
rename to tests/conf/inria_test.yaml
diff --git a/tests/conf/inria_train.yaml b/tests/conf/inria_train.yaml
new file mode 100644
index 00000000000..99db7925f27
--- /dev/null
+++ b/tests/conf/inria_train.yaml
@@ -0,0 +1,20 @@
+experiment:
+  task: "inria"
+  module:
+    loss: "ce"
+    model: "unet"
+    backbone: "resnet18"
+    weights: "imagenet"
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 6
+    in_channels: 3
+    num_classes: 2
+    ignore_index: null
+  datamodule:
+      root: "tests/data/inria"
+      batch_size: 1
+      num_workers: 0
+      val_split_pct: 0.0
+      test_split_pct: 0.0
+      patch_size: 2
+      num_patches_per_tile: 2
diff --git a/tests/conf/inria_val.yaml b/tests/conf/inria_val.yaml
new file mode 100644
index 00000000000..c20f8923439
--- /dev/null
+++ b/tests/conf/inria_val.yaml
@@ -0,0 +1,20 @@
+experiment:
+  task: "inria"
+  module:
+    loss: "ce"
+    model: "unet"
+    backbone: "resnet18"
+    weights: "imagenet"
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 6
+    in_channels: 3
+    num_classes: 2
+    ignore_index: null
+  datamodule:
+      root: "tests/data/inria"
+      batch_size: 1
+      num_workers: 0
+      val_split_pct: 0.2
+      test_split_pct: 0.0
+      patch_size: 2
+      num_patches_per_tile: 2
diff --git a/tests/datamodules/test_inria.py b/tests/datamodules/test_inria.py
deleted file mode 100644
index e4415db96e5..00000000000
--- a/tests/datamodules/test_inria.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-
-import os
-
-import pytest
-from _pytest.fixtures import SubRequest
-
-from torchgeo.datamodules import InriaAerialImageLabelingDataModule
-
-TEST_DATA_DIR = os.path.join("tests", "data", "inria")
-
-
-class TestInriaAerialImageLabelingDataModule:
-    @pytest.fixture(params=zip([0.2, 0.2, 0.0], [0.2, 0.0, 0.0]))
-    def datamodule(self, request: SubRequest) -> InriaAerialImageLabelingDataModule:
-        val_split_pct, test_split_pct = request.param
-        patch_size = 2  # (2,2)
-        num_patches_per_tile = 2
-        root = TEST_DATA_DIR
-        batch_size = 1
-        num_workers = 0
-        dm = InriaAerialImageLabelingDataModule(
-            root=root,
-            batch_size=batch_size,
-            num_workers=num_workers,
-            val_split_pct=val_split_pct,
-            test_split_pct=test_split_pct,
-            patch_size=patch_size,
-            num_patches_per_tile=num_patches_per_tile,
-        )
-        dm.prepare_data()
-        dm.setup()
-        return dm
-
-    def test_train_dataloader(
-        self, datamodule: InriaAerialImageLabelingDataModule
-    ) -> None:
-        sample = next(iter(datamodule.train_dataloader()))
-        assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2)
-        assert sample["image"].shape[0] == sample["mask"].shape[0] == 2
-        assert sample["image"].shape[1] == 3
-        assert sample["mask"].shape[1] == 1
-
-    def test_val_dataloader(
-        self, datamodule: InriaAerialImageLabelingDataModule
-    ) -> None:
-        sample = next(iter(datamodule.val_dataloader()))
-        if datamodule.val_split_pct > 0.0:
-            assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2)
-            assert sample["image"].shape[0] == sample["mask"].shape[0] == 2
-
-    def test_test_dataloader(
-        self, datamodule: InriaAerialImageLabelingDataModule
-    ) -> None:
-        sample = next(iter(datamodule.test_dataloader()))
-        if datamodule.test_split_pct > 0.0:
-            assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2)
-            assert sample["image"].shape[0] == sample["mask"].shape[0] == 2
-
-    def test_predict_dataloader(
-        self, datamodule: InriaAerialImageLabelingDataModule
-    ) -> None:
-        sample = next(iter(datamodule.predict_dataloader()))
-        assert len(sample["image"].shape) == 5
-        assert sample["image"].shape[-2:] == (2, 2)
-        assert sample["image"].shape[2] == 3
diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py
index dc6be108adf..d30df3d993d 100644
--- a/tests/trainers/test_segmentation.py
+++ b/tests/trainers/test_segmentation.py
@@ -40,7 +40,9 @@ class TestSemanticSegmentationTask:
             ("deepglobelandcover_0", DeepGlobeLandCoverDataModule),
             ("deepglobelandcover_5", DeepGlobeLandCoverDataModule),
             ("etci2021", ETCI2021DataModule),
-            ("inria", InriaAerialImageLabelingDataModule),
+            ("inria_train", InriaAerialImageLabelingDataModule),
+            ("inria_val", InriaAerialImageLabelingDataModule),
+            ("inria_test", InriaAerialImageLabelingDataModule),
             ("landcoverai", LandCoverAIDataModule),
             ("naipchesapeake", NAIPChesapeakeDataModule),
             ("oscd_all", OSCDDataModule),
@@ -80,7 +82,9 @@ def test_trainer(
         trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1)
         trainer.fit(model=model, datamodule=datamodule)
         trainer.test(model=model, datamodule=datamodule)
-        trainer.predict(model=model, dataloaders=datamodule.val_dataloader())
+
+        if hasattr(datamodule, "predict_dataset"):
+            trainer.predict(model=model, datamodule=datamodule)
 
     def test_no_logger(self) -> None:
         conf = OmegaConf.load(os.path.join("tests", "conf", "landcoverai.yaml"))
diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py
index df8b07f0c21..92d26dc87b2 100644
--- a/torchgeo/datamodules/inria.py
+++ b/torchgeo/datamodules/inria.py
@@ -98,7 +98,9 @@ def patch_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
             self.patch_size,
             padding=padding,
         )
-        sample["image"] = rearrange(sample["image"], "() t c h w -> t () c h w")
+        # Needed for reconstruction of patches later
+        sample["num_patches"] = sample["image"].shape[1]
+        sample["image"] = rearrange(sample["image"], "b n c h w -> (b n) c h w")
         return sample
 
     def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: