diff --git a/conf/inria.yaml b/conf/inria.yaml new file mode 100644 index 00000000000..c7e647503f8 --- /dev/null +++ b/conf/inria.yaml @@ -0,0 +1,30 @@ +program: + overwrite: True + + +trainer: + gpus: 1 + min_epochs: 5 + max_epochs: 100 + benchmark: True + log_every_n_steps: 2 + +experiment: + task: "inria" + name: "inria_test" + module: + loss: "ce" + segmentation_model: "unet" + encoder_name: "resnet18" + encoder_weights: "imagenet" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 3 + num_classes: 2 + ignore_zeros: True # class 0 not used for scoring + datamodule: + root_dir: "data/inria" + batch_size: 2 + num_workers: 32 + patch_size: 512 + num_patches_per_tile: 4 diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index ad47f9cd648..8da3645f4e7 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -44,6 +44,11 @@ FAIR1M (Fine-grAined object recognItion in high-Resolution imagery) .. autoclass:: FAIR1MDataModule +Inria Aerial Image Labeling +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. 
autoclass:: InriaAerialImageLabelingDataModule + LandCover.ai (Land Cover from Aerial Imagery) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/conf.py b/docs/conf.py index b346b8e3914..558bf9b710e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -104,6 +104,7 @@ # sphinx.ext.intersphinx intersphinx_mapping = { + "kornia": ("https://kornia.readthedocs.io/en/stable/", None), "matplotlib": ("https://matplotlib.org/stable/", None), "numpy": ("https://numpy.org/doc/stable/", None), "python": ("https://docs.python.org/3", None), diff --git a/setup.cfg b/setup.cfg index 349186e4905..ce86d6348d1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,8 +31,8 @@ install_requires = einops # fiona 1.5+ required for fiona.transform module fiona>=1.5 - # kornia 0.5.11+ required for kornia.augmentation.PadTo - kornia>=0.5.11 + # kornia 0.6.4+ required for kornia.contrib.compute_padding + kornia>=0.6.4 matplotlib numpy # omegaconf 2.1+ required for to_object method diff --git a/tests/conf/inria.yaml b/tests/conf/inria.yaml new file mode 100644 index 00000000000..5bd48eb4594 --- /dev/null +++ b/tests/conf/inria.yaml @@ -0,0 +1,20 @@ +experiment: + task: "inria" + module: + loss: "ce" + segmentation_model: "unet" + encoder_name: "resnet18" + encoder_weights: "imagenet" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 3 + num_classes: 2 + ignore_zeros: True # class 0 not used for scoring + datamodule: + root_dir: "tests/data/inria" + batch_size: 1 + num_workers: 0 + val_split_pct: 0.2 + test_split_pct: 0.2 + patch_size: 2 + num_patches_per_tile: 2 diff --git a/tests/data/inria/AerialImageDataset/test/images/austin10.tif b/tests/data/inria/AerialImageDataset/test/images/austin10.tif index 0b92a71db98..d77ca4e7fa4 100644 Binary files a/tests/data/inria/AerialImageDataset/test/images/austin10.tif and b/tests/data/inria/AerialImageDataset/test/images/austin10.tif differ diff --git a/tests/data/inria/AerialImageDataset/test/images/austin11.tif 
b/tests/data/inria/AerialImageDataset/test/images/austin11.tif index d7e1c226e2f..0042958d9b7 100644 Binary files a/tests/data/inria/AerialImageDataset/test/images/austin11.tif and b/tests/data/inria/AerialImageDataset/test/images/austin11.tif differ diff --git a/tests/data/inria/AerialImageDataset/test/images/austin12.tif b/tests/data/inria/AerialImageDataset/test/images/austin12.tif new file mode 100644 index 00000000000..c7c12752406 Binary files /dev/null and b/tests/data/inria/AerialImageDataset/test/images/austin12.tif differ diff --git a/tests/data/inria/AerialImageDataset/test/images/austin13.tif b/tests/data/inria/AerialImageDataset/test/images/austin13.tif new file mode 100644 index 00000000000..029444bc99f Binary files /dev/null and b/tests/data/inria/AerialImageDataset/test/images/austin13.tif differ diff --git a/tests/data/inria/AerialImageDataset/test/images/austin14.tif b/tests/data/inria/AerialImageDataset/test/images/austin14.tif new file mode 100644 index 00000000000..6a84ce1c5c9 Binary files /dev/null and b/tests/data/inria/AerialImageDataset/test/images/austin14.tif differ diff --git a/tests/data/inria/AerialImageDataset/train/gt/austin1.tif b/tests/data/inria/AerialImageDataset/train/gt/austin1.tif index 779eb9c5560..9bf2873fff4 100644 Binary files a/tests/data/inria/AerialImageDataset/train/gt/austin1.tif and b/tests/data/inria/AerialImageDataset/train/gt/austin1.tif differ diff --git a/tests/data/inria/AerialImageDataset/train/gt/austin2.tif b/tests/data/inria/AerialImageDataset/train/gt/austin2.tif index bd75c445390..b06a9363da4 100644 Binary files a/tests/data/inria/AerialImageDataset/train/gt/austin2.tif and b/tests/data/inria/AerialImageDataset/train/gt/austin2.tif differ diff --git a/tests/data/inria/AerialImageDataset/train/gt/austin3.tif b/tests/data/inria/AerialImageDataset/train/gt/austin3.tif new file mode 100644 index 00000000000..2d134842907 Binary files /dev/null and b/tests/data/inria/AerialImageDataset/train/gt/austin3.tif 
differ diff --git a/tests/data/inria/AerialImageDataset/train/gt/austin4.tif b/tests/data/inria/AerialImageDataset/train/gt/austin4.tif new file mode 100644 index 00000000000..21cb217cef3 Binary files /dev/null and b/tests/data/inria/AerialImageDataset/train/gt/austin4.tif differ diff --git a/tests/data/inria/AerialImageDataset/train/gt/austin5.tif b/tests/data/inria/AerialImageDataset/train/gt/austin5.tif new file mode 100644 index 00000000000..3a819af9b82 Binary files /dev/null and b/tests/data/inria/AerialImageDataset/train/gt/austin5.tif differ diff --git a/tests/data/inria/AerialImageDataset/train/images/austin1.tif b/tests/data/inria/AerialImageDataset/train/images/austin1.tif index 8cbd1eba8f5..cabf7459159 100644 Binary files a/tests/data/inria/AerialImageDataset/train/images/austin1.tif and b/tests/data/inria/AerialImageDataset/train/images/austin1.tif differ diff --git a/tests/data/inria/AerialImageDataset/train/images/austin2.tif b/tests/data/inria/AerialImageDataset/train/images/austin2.tif index 466beacf67c..df55cf5b7bc 100644 Binary files a/tests/data/inria/AerialImageDataset/train/images/austin2.tif and b/tests/data/inria/AerialImageDataset/train/images/austin2.tif differ diff --git a/tests/data/inria/AerialImageDataset/train/images/austin3.tif b/tests/data/inria/AerialImageDataset/train/images/austin3.tif new file mode 100644 index 00000000000..c99ea6c2637 Binary files /dev/null and b/tests/data/inria/AerialImageDataset/train/images/austin3.tif differ diff --git a/tests/data/inria/AerialImageDataset/train/images/austin4.tif b/tests/data/inria/AerialImageDataset/train/images/austin4.tif new file mode 100644 index 00000000000..33dc4eefa32 Binary files /dev/null and b/tests/data/inria/AerialImageDataset/train/images/austin4.tif differ diff --git a/tests/data/inria/AerialImageDataset/train/images/austin5.tif b/tests/data/inria/AerialImageDataset/train/images/austin5.tif new file mode 100644 index 00000000000..2e973747a15 Binary files /dev/null and 
b/tests/data/inria/AerialImageDataset/train/images/austin5.tif differ diff --git a/tests/data/inria/NEW2-AerialImageDataset.zip b/tests/data/inria/NEW2-AerialImageDataset.zip index 8a4ff559c91..153d3b39959 100644 Binary files a/tests/data/inria/NEW2-AerialImageDataset.zip and b/tests/data/inria/NEW2-AerialImageDataset.zip differ diff --git a/tests/data/inria/data.py b/tests/data/inria/data.py index f11c34cf9a6..4cb68c9d587 100644 --- a/tests/data/inria/data.py +++ b/tests/data/inria/data.py @@ -43,7 +43,7 @@ def generate_test_data(root: str, n_samples: int = 2) -> str: str: md5 hash of created archive """ dtype = np.dtype("uint8") - size = (64, 64) + size = (8, 8) driver = "GTiff" transform = Affine(0.3, 0.0, 616500.0, 0.0, -0.3, 3345000.0) @@ -83,9 +83,9 @@ def generate_test_data(root: str, n_samples: int = 2) -> str: archive_path, "zip", root_dir=root, base_dir="AerialImageDataset" ) shutil.rmtree(folder_path) - return calculate_md5(archive_path + ".zip") + return calculate_md5(f"{archive_path}.zip") if __name__ == "__main__": - md5_hash = generate_test_data(os.getcwd(), 2) + md5_hash = generate_test_data(os.getcwd(), 5) print(md5_hash) diff --git a/tests/datamodules/test_inria.py b/tests/datamodules/test_inria.py new file mode 100644 index 00000000000..8ad1d33d2d6 --- /dev/null +++ b/tests/datamodules/test_inria.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ +import os + +import pytest +from _pytest.fixtures import SubRequest + +from torchgeo.datamodules import InriaAerialImageLabelingDataModule + +TEST_DATA_DIR = os.path.join("tests", "data", "inria") + + +class TestInriaAerialImageLabelingDataModule: + @pytest.fixture( + params=zip([0.2, 0.2, 0.0], [0.2, 0.0, 0.0], ["test", "test", "test"]) + ) + def datamodule(self, request: SubRequest) -> InriaAerialImageLabelingDataModule: + val_split_pct, test_split_pct, predict_on = request.param + patch_size = 2 # (2,2) + num_patches_per_tile = 2 + root = TEST_DATA_DIR + batch_size = 1 + num_workers = 0 + dm = InriaAerialImageLabelingDataModule( + root, + batch_size, + num_workers, + val_split_pct, + test_split_pct, + patch_size, + num_patches_per_tile, + predict_on=predict_on, + ) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader( + self, datamodule: InriaAerialImageLabelingDataModule + ) -> None: + sample = next(iter(datamodule.train_dataloader())) + assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) + assert sample["image"].shape[0] == sample["mask"].shape[0] == 2 + assert sample["image"].shape[1] == 3 + assert sample["mask"].shape[1] == 1 + + def test_val_dataloader( + self, datamodule: InriaAerialImageLabelingDataModule + ) -> None: + sample = next(iter(datamodule.val_dataloader())) + if datamodule.val_split_pct > 0.0: + assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) + assert sample["image"].shape[0] == sample["mask"].shape[0] == 2 + + def test_test_dataloader( + self, datamodule: InriaAerialImageLabelingDataModule + ) -> None: + sample = next(iter(datamodule.test_dataloader())) + if datamodule.test_split_pct > 0.0: + assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) + assert sample["image"].shape[0] == sample["mask"].shape[0] == 2 + + def test_predict_dataloader( + self, datamodule: InriaAerialImageLabelingDataModule + ) -> None: + sample = 
next(iter(datamodule.predict_dataloader())) + assert len(sample["image"].shape) == 5 + assert sample["image"].shape[-2:] == (2, 2) + assert sample["image"].shape[2] == 3 diff --git a/tests/datasets/test_inria.py b/tests/datasets/test_inria.py index 676b22e3c79..1a769e120d5 100644 --- a/tests/datasets/test_inria.py +++ b/tests/datasets/test_inria.py @@ -21,7 +21,7 @@ def dataset( ) -> InriaAerialImageLabeling: root = os.path.join("tests", "data", "inria") - test_md5 = "f23caf363389ef59de55fad11197c161" + test_md5 = "478688944e4797c097d9387fd0b3f038" monkeypatch.setattr(InriaAerialImageLabeling, "md5", test_md5) transforms = nn.Identity() # type: ignore[no-untyped-call] return InriaAerialImageLabeling( @@ -39,7 +39,7 @@ def test_getitem(self, dataset: InriaAerialImageLabeling) -> None: assert x["image"].ndim == 3 def test_len(self, dataset: InriaAerialImageLabeling) -> None: - assert len(dataset) == 2 + assert len(dataset) == 5 def test_already_downloaded(self, dataset: InriaAerialImageLabeling) -> None: InriaAerialImageLabeling(root=dataset.root) diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 9332658dc89..20c87a5aa6d 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -14,6 +14,7 @@ from torchgeo.datamodules import ( ChesapeakeCVPRDataModule, ETCI2021DataModule, + InriaAerialImageLabelingDataModule, LandCoverAIDataModule, NAIPChesapeakeDataModule, OSCDDataModule, @@ -34,6 +35,7 @@ class TestSemanticSegmentationTask: [ ("chesapeake_cvpr_5", ChesapeakeCVPRDataModule), ("etci2021", ETCI2021DataModule), + ("inria", InriaAerialImageLabelingDataModule), ("landcoverai", LandCoverAIDataModule), ("naipchesapeake", NAIPChesapeakeDataModule), ("oscd_all", OSCDDataModule), diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index e09fe0ab378..64f4c2d8d95 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -10,6 +10,7 @@ from 
.etci2021 import ETCI2021DataModule from .eurosat import EuroSATDataModule from .fair1m import FAIR1MDataModule +from .inria import InriaAerialImageLabelingDataModule from .landcoverai import LandCoverAIDataModule from .loveda import LoveDADataModule from .naip import NAIPChesapeakeDataModule @@ -33,6 +34,7 @@ "ETCI2021DataModule", "EuroSATDataModule", "FAIR1MDataModule", + "InriaAerialImageLabelingDataModule", "LandCoverAIDataModule", "LoveDADataModule", "NASAMarineDebrisDataModule", diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py new file mode 100644 index 00000000000..88f502c59bd --- /dev/null +++ b/torchgeo/datamodules/inria.py @@ -0,0 +1,242 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""InriaAerialImageLabeling datamodule.""" + +from typing import Any, Dict, List, Optional, Tuple, Union, cast + +import kornia.augmentation as K +import pytorch_lightning as pl +import torch +import torchvision.transforms as T +from einops import rearrange +from kornia.contrib import compute_padding, extract_tensor_patches +from torch.utils.data import DataLoader, Dataset +from torch.utils.data._utils.collate import default_collate + +from ..datasets import InriaAerialImageLabeling +from ..samplers.utils import _to_tuple +from .utils import dataset_split + +DEFAULT_AUGS = K.AugmentationSequential( + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), + data_keys=["input", "mask"], +) + + +def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]: + """Flatten wrapper.""" + r_batch: Dict[str, Any] = default_collate(batch) # type: ignore[no-untyped-call] + r_batch["image"] = torch.flatten(r_batch["image"], 0, 1) + if "mask" in r_batch: + r_batch["mask"] = torch.flatten(r_batch["mask"], 0, 1) + + return r_batch + + +class InriaAerialImageLabelingDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the InriaAerialImageLabeling dataset. 
+ + Uses the train/test splits from the dataset and further splits + the train split into train/val splits. + + .. versionadded:: 0.3 + """ + + h, w = 5000, 5000 + + def __init__( + self, + root_dir: str, + batch_size: int = 32, + num_workers: int = 0, + val_split_pct: float = 0.1, + test_split_pct: float = 0.1, + patch_size: Union[int, Tuple[int, int]] = 512, + num_patches_per_tile: int = 32, + augmentations: K.AugmentationSequential = DEFAULT_AUGS, + predict_on: str = "test", + ) -> None: + """Initialize a LightningDataModule for InriaAerialImageLabeling based DataLoaders. + + Args: + root_dir: The ``root`` argument to pass to the InriaAerialImageLabeling + Dataset classes + batch_size: The batch size used in the train DataLoader + (val_batch_size == test_batch_size == 1) + num_workers: The number of workers to use in all created DataLoaders + val_split_pct: What percentage of the dataset to use as a validation set + test_split_pct: What percentage of the dataset to use as a test set + patch_size: Size of random patch from image and mask (height, width) + num_patches_per_tile: Number of random patches per sample + augmentations: Default augmentations applied + predict_on: Directory/Dataset of images to run inference on + """ + super().__init__()  # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.val_split_pct = val_split_pct + self.test_split_pct = test_split_pct + self.patch_size = cast(Tuple[int, int], _to_tuple(patch_size)) + self.num_patches_per_tile = num_patches_per_tile + self.augmentations = augmentations + self.predict_on = predict_on + self.random_crop = K.AugmentationSequential( + K.RandomCrop(self.patch_size, p=1.0, keepdim=False), + data_keys=["input", "mask"], + ) + + def patch_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Extract patches from single sample.""" + assert sample["image"].ndim == 3 + _, h, w = sample["image"].shape + + padding = 
compute_padding((h, w), self.patch_size) + sample["original_shape"] = (h, w) + sample["patch_shape"] = self.patch_size + sample["padding"] = padding + sample["image"] = extract_tensor_patches( + sample["image"].unsqueeze(0), + self.patch_size, + self.patch_size, + padding=padding, + ) + sample["image"] = rearrange(sample["image"], "() t c h w -> t () c h w") + return sample + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset.""" + # RGB is int32 so divide by 255 + sample["image"] = sample["image"] / 255.0 + sample["image"] = torch.clip(sample["image"], min=0.0, max=1.0) + + if "mask" in sample: + sample["mask"] = rearrange(sample["mask"], "h w -> () h w") + + return sample + + def n_random_crop(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Get n random crops.""" + images, masks = [], [] + for _ in range(self.num_patches_per_tile): + image, mask = sample["image"], sample["mask"] + # RandomCrop needs image and mask to be in float + mask = mask.to(torch.float) + image, mask = self.random_crop(image, mask) + images.append(image.squeeze()) + masks.append(mask.squeeze(0).long()) + sample["image"] = torch.stack(images) # (t,c,h,w) + sample["mask"] = torch.stack(masks) # (t, 1, h, w) + return sample + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. 
+ """ + train_transforms = T.Compose([self.preprocess, self.n_random_crop]) + test_transforms = T.Compose([self.preprocess, self.patch_sample]) + + train_dataset = InriaAerialImageLabeling( + self.root_dir, split="train", transforms=train_transforms + ) + + self.train_dataset: Dataset[Any] + self.val_dataset: Dataset[Any] + self.test_dataset: Dataset[Any] + + if self.val_split_pct > 0.0: + if self.test_split_pct > 0.0: + self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( + train_dataset, + val_pct=self.val_split_pct, + test_pct=self.test_split_pct, + ) + else: + self.train_dataset, self.val_dataset = dataset_split( + train_dataset, val_pct=self.val_split_pct + ) + self.test_dataset = self.val_dataset + else: + self.train_dataset = train_dataset + self.val_dataset = train_dataset + self.test_dataset = train_dataset + + assert self.predict_on == "test" + self.predict_dataset = InriaAerialImageLabeling( + self.root_dir, self.predict_on, transforms=test_transforms + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training.""" + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=collate_wrapper, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation.""" + return DataLoader( + self.val_dataset, + batch_size=1, + num_workers=self.num_workers, + collate_fn=collate_wrapper, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing.""" + return DataLoader( + self.test_dataset, + batch_size=1, + num_workers=self.num_workers, + collate_fn=collate_wrapper, + shuffle=False, + ) + + def predict_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for prediction.""" + return DataLoader( + self.predict_dataset, + batch_size=1, + num_workers=self.num_workers, + collate_fn=collate_wrapper, + shuffle=False, + ) + + def on_after_batch_transfer( + self, 
batch: Dict[str, Any], dataloader_idx: int + ) -> Dict[str, Any]: + """Apply augmentations to batch after transferring to GPU. + + Args: + batch (dict): A batch of data that needs to be altered or augmented. + dataloader_idx (int): The index of the dataloader to which the batch + belongs. + + Returns: + dict: A batch of data + """ + # Training + if ( + hasattr(self, "trainer") + and self.trainer is not None + and hasattr(self.trainer, "training") + and self.trainer.training + and self.augmentations is not None + ): + batch["mask"] = batch["mask"].to(torch.float) + batch["image"], batch["mask"] = self.augmentations( + batch["image"], batch["mask"] + ) + batch["mask"] = batch["mask"].to(torch.long) + + # Validation + if "mask" in batch: + batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w") + return batch diff --git a/torchgeo/datasets/inria.py b/torchgeo/datasets/inria.py index 53d61129708..813e61a2d8a 100644 --- a/torchgeo/datasets/inria.py +++ b/torchgeo/datasets/inria.py @@ -14,12 +14,8 @@ from matplotlib.figure import Figure from torch import Tensor -from torchgeo.datasets.geo import VisionDataset -from torchgeo.datasets.utils import ( - check_integrity, - extract_archive, - percentile_normalization, -) +from .geo import VisionDataset +from .utils import check_integrity, extract_archive, percentile_normalization class InriaAerialImageLabeling(VisionDataset): @@ -103,10 +99,10 @@ def _load_files(self, root: str) -> List[Dict[str, str]]: labels = sorted(labels) for img, lbl in zip(images, labels): - files.append({"image_path": img, "label_path": lbl}) + files.append({"image": img, "label": lbl}) else: for img in images: - files.append({"image_path": img}) + files.append({"image": img}) return files @@ -157,11 +153,10 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: data and label at that index """ files = self.files[index] - sample = {} - img = self._load_image(files["image_path"]) - sample["image"] = img - if files.get("label_path"): - mask = 
self._load_target(files["label_path"]) + img = self._load_image(files["image"]) + sample = {"image": img} + if files.get("label"): + mask = self._load_target(files["label"]) sample["mask"] = mask if self.transforms is not None: diff --git a/train.py b/train.py index 101ecc6e9f1..8aab0b81463 100755 --- a/train.py +++ b/train.py @@ -20,6 +20,7 @@ CycloneDataModule, ETCI2021DataModule, EuroSATDataModule, + InriaAerialImageLabelingDataModule, LandCoverAIDataModule, NAIPChesapeakeDataModule, OSCDDataModule, @@ -46,6 +47,7 @@ "cyclone": (RegressionTask, CycloneDataModule), "eurosat": (ClassificationTask, EuroSATDataModule), "etci2021": (SemanticSegmentationTask, ETCI2021DataModule), + "inria": (SemanticSegmentationTask, InriaAerialImageLabelingDataModule), "landcoverai": (SemanticSegmentationTask, LandCoverAIDataModule), "naipchesapeake": (SemanticSegmentationTask, NAIPChesapeakeDataModule), "oscd": (SemanticSegmentationTask, OSCDDataModule),