Add prediction utilities #560

Status: Closed. Wants to merge 5 commits; the diff below shows changes from 3 of those commits.
205 changes: 205 additions & 0 deletions predict.py
@@ -0,0 +1,205 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""torchgeo model inference script."""

import argparse
import os
from typing import Dict, Tuple, Type, cast

import pytorch_lightning as pl
import rasterio as rio
import torch
from kornia.contrib import CombineTensorPatches
from omegaconf import OmegaConf

from torchgeo.datamodules import (
BigEarthNetDataModule,
ChesapeakeCVPRDataModule,
COWCCountingDataModule,
CycloneDataModule,
ETCI2021DataModule,
EuroSATDataModule,
InriaAerialImageLabelingDataModule,
LandCoverAIDataModule,
NAIPChesapeakeDataModule,
OSCDDataModule,
RESISC45DataModule,
SEN12MSDataModule,
So2SatDataModule,
UCMercedDataModule,
)
from torchgeo.trainers import (
BYOLTask,
ClassificationTask,
MultiLabelClassificationTask,
RegressionTask,
SemanticSegmentationTask,
)

TASK_TO_MODULES_MAPPING: Dict[
Review comment (Collaborator): Now that this is in two different scripts, we should move it somewhere both scripts can find it, to avoid code duplication. How about a torchgeo/common.py file? I don't want to put it in torchgeo/__init__.py, because that would be sourced on every import.

Reply (Collaborator, author): Like the idea of torchgeo/common.py. I can add that in a follow-up PR.

Reply (Member): +1 to this

str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]]
] = {
"bigearthnet": (MultiLabelClassificationTask, BigEarthNetDataModule),
"byol": (BYOLTask, ChesapeakeCVPRDataModule),
"chesapeake_cvpr": (SemanticSegmentationTask, ChesapeakeCVPRDataModule),
"cowc_counting": (RegressionTask, COWCCountingDataModule),
"cyclone": (RegressionTask, CycloneDataModule),
"eurosat": (ClassificationTask, EuroSATDataModule),
"etci2021": (SemanticSegmentationTask, ETCI2021DataModule),
"inria": (SemanticSegmentationTask, InriaAerialImageLabelingDataModule),
"landcoverai": (SemanticSegmentationTask, LandCoverAIDataModule),
"naipchesapeake": (SemanticSegmentationTask, NAIPChesapeakeDataModule),
"oscd": (SemanticSegmentationTask, OSCDDataModule),
"resisc45": (ClassificationTask, RESISC45DataModule),
"sen12ms": (SemanticSegmentationTask, SEN12MSDataModule),
"so2sat": (ClassificationTask, So2SatDataModule),
"ucmerced": (ClassificationTask, UCMercedDataModule),
}
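As suggested in the review thread above, this mapping is a candidate for a shared torchgeo/common.py module. A minimal sketch of what that hypothetical file could look like follows; the module name and contents are the reviewers' suggestion, not part of this PR:

```python
# Hypothetical torchgeo/common.py (suggested in review, not part of this PR):
# constants shared by train.py and predict.py so they are defined only once.
from typing import Dict, Tuple, Type

import pytorch_lightning as pl

from torchgeo.datamodules import CycloneDataModule  # ...plus the other datamodules
from torchgeo.trainers import RegressionTask  # ...plus the other trainers

TASK_TO_MODULES_MAPPING: Dict[
    str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]]
] = {
    "cyclone": (RegressionTask, CycloneDataModule),
    # ...remaining entries exactly as in predict.py above
}

# The GDAL/rasterio environment defaults defined further down in predict.py
# could live here as well, as a later review comment suggests.
```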


def write_mask(mask: torch.Tensor, output_dir: str, input_filename: str) -> None:
"""Write mask to specified output directory with same filename as input raster.

Args:
mask (torch.Tensor): mask tensor
output_dir (str): output directory
input_filename (str): path to input raster
"""
output_path = os.path.join(output_dir, os.path.basename(input_filename))
with rio.open(input_filename) as src:
profile = src.profile
profile["count"] = 1
profile["dtype"] = "uint8"
mask = mask.cpu().numpy()
with rio.open(output_path, "w", **profile) as ds:
ds.write(mask)


def main(config_dir: str, predict_on: str, output_dir: str, device: str) -> None:
"""Main inference loop.

Args:
config_dir (str): Path to config-dir to load config and ckpt
predict_on (str): Directory/Dataset to run inference on
output_dir (str): Path to output_directory to save predicted masks
device (str): Choice of device. Must be in [cuda, cpu]

Raises:
ValueError: Raised if task name is not in TASK_TO_MODULES_MAPPING
FileExistsError: Raised if specified output directory contains
files and overwrite=False.
"""
os.makedirs(output_dir, exist_ok=True)

# Load checkpoint and config
conf = OmegaConf.load(os.path.join(config_dir, "experiment_config.yaml"))
ckpt = os.path.join(config_dir, "last.ckpt")
Comment on lines +98 to +99:

Review comment (Collaborator): Should these filenames be parameters?

Reply (Collaborator, author): The ckpt filename, sure. I think the config filename ("experiment_config.yaml") is hard-coded, though.
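For illustration, one way the checkpoint filename could be exposed as a parameter, as discussed above. This is only a sketch; the --ckpt-name flag is hypothetical and not part of the PR:

```python
# Hypothetical sketch: make the checkpoint filename configurable while keeping
# "last.ckpt" as the default. The flag name is illustrative only.
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument("--config-dir", type=str, required=True)
parser.add_argument(
    "--ckpt-name",
    type=str,
    default="last.ckpt",
    help="Checkpoint filename inside --config-dir",
)
args = parser.parse_args()
ckpt = os.path.join(args.config_dir, args.ckpt_name)
```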


# Load model
task_name = conf.experiment.task
datamodule: pl.LightningDataModule
task: pl.LightningModule
if task_name not in TASK_TO_MODULES_MAPPING:
raise ValueError(
f"experiment.task={task_name} is not recognized as a valid task"
)
task_class, datamodule_class = TASK_TO_MODULES_MAPPING[task_name]
task = task_class.load_from_checkpoint(ckpt)
task = task.to(device)
task.eval()

# Load datamodule and dataloader
conf.experiment.datamodule["predict_on"] = predict_on
datamodule = datamodule_class(**conf.experiment.datamodule)
datamodule.setup()
dataloader = datamodule.predict_dataloader()

if len(os.listdir(output_dir)) > 0:
if conf.program.overwrite:
print(
f"WARNING! The output directory, {output_dir}, already exists, "
+ "we will overwrite data in it!"
)
else:
raise FileExistsError(
f"The predictions directory, {output_dir}, already exists and isn't "
+ "empty. We don't want to overwrite any existing results, exiting..."
)

for i, batch in enumerate(dataloader):
x = batch["image"].to(device) # (N, B, C, H, W)
assert len(x.shape) in {4, 5}
if len(x.shape) == 5:
masks = []

def tensor_to_int(
tensor_tuple: Tuple[torch.Tensor, ...]
) -> Tuple[int, ...]:
"""Convert tuple of tensors to tuple of ints."""
return tuple(int(i.item()) for i in tensor_tuple)

original_shape = cast(
Review comment (Collaborator): These shouldn't require a cast if everything is typed correctly.

Tuple[int, int], tensor_to_int(batch["original_shape"])
)
patch_shape = cast(Tuple[int, int], tensor_to_int(batch["patch_shape"]))
padding = cast(Tuple[int, int], tensor_to_int(batch["padding"]))
patch_combine = CombineTensorPatches(
Review comment (Collaborator): Can you tell me more about how this function works? What happens if your patches have overlap? Is this sufficient to close #30, or do we need something more powerful/generic?

Reply (Collaborator, author):

"Can you tell me more about how this function works?"
I've written a tutorial describing how extract and combine patches work here.

"What happens if your patches have overlap?"
Currently you can extract patches with overlap (via the stride parameter), but you can't merge them back together. This is because CombineTensorPatches currently only supports stride=window_size, as seen here.

"Is this sufficient to close #30 or do we need something more powerful/generic?"
Based on my understanding of the scope of #30, I would say no. If the goal is just to let users extract patches and stitch them back together, kornia's ExtractTensorPatches and CombineTensorPatches are sufficient. Once CombineTensorPatches supports stride != window_size, we will also be able to handle overlapping patches. But for the alternative stitching techniques (like label averaging) mentioned in the paper referenced in #30, we might need something more powerful, since CombineTensorPatches doesn't support them.

original_size=original_shape, window_size=patch_shape, unpadding=padding
)
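To make the discussion above concrete, here is a standalone round trip through kornia's patch utilities. It is illustrative only and assumes non-overlapping patches, i.e. stride equal to window_size, which is the case CombineTensorPatches supports:

```python
# Illustrative sketch of kornia's extract/combine round trip (not part of the PR).
import torch
from kornia.contrib import CombineTensorPatches, ExtractTensorPatches

image = torch.randn(1, 3, 64, 64)  # (B, C, H, W)
window = (32, 32)

extract = ExtractTensorPatches(window_size=window, stride=window)
combine = CombineTensorPatches(original_size=(64, 64), window_size=window)

patches = extract(image)  # (B, N, C, 32, 32) with N = 4 non-overlapping patches
restored = combine(patches)  # back to (B, C, 64, 64)
assert torch.allclose(image, restored)  # exact when stride == window_size
```

With overlapping patches (stride smaller than window_size) the combine step would need a blending rule such as the label averaging mentioned above, which CombineTensorPatches does not provide.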

for tile in x:
mask = task(tile)
mask = mask.argmax(dim=1)
masks.append(mask)

masks_arr = torch.stack(masks, dim=0)
masks_arr = masks_arr.unsqueeze(0)
masks_combined = patch_combine(masks_arr)[0]
filename = datamodule.predict_dataset.files[i]["image"]
write_mask(masks_combined, output_dir, filename)
else:
mask = task(x)
mask = mask.argmax(dim=1)
filename = datamodule.predict_dataset.files[i]["image"]
write_mask(mask, output_dir, filename)


if __name__ == "__main__":
# Taken from https://github.com/pangeo-data/cog-best-practices
_rasterio_best_practices = {
Review comment (Collaborator): This can also be moved to torchgeo/common.py.

"GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR",
"AWS_NO_SIGN_REQUEST": "YES",
"GDAL_MAX_RAW_BLOCK_CACHE_SIZE": "200000000",
"GDAL_SWATH_SIZE": "200000000",
"VSI_CURL_CACHE_SIZE": "200000000",
}
os.environ.update(_rasterio_best_practices)

parser = argparse.ArgumentParser()
parser.add_argument(
"--config-dir",
type=str,
required=True,
help="Path to config-dir to load config and ckpt",
)

parser.add_argument(
"--predict_on",
type=str,
required=True,
help="Directory/Dataset to run inference on",
)

parser.add_argument(
"--output-dir",
type=str,
required=True,
help="Path to output_directory to save predicted mask geotiffs",
)

parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
Review comment (Collaborator): Is this something we should let PyTorch Lightning handle?

Reply (Collaborator, author): Perhaps. LightningModules shouldn't need this, but I'll need to verify. Will get back to this.

Reply (Collaborator, author, @ashnair1, Jun 20, 2022): While LightningModules are aware of which device they're on, the models they wrap (UNet, DeepLab, etc.) are not. Since we can't forward device info, this argument is still required.
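A small illustration of the point made in the last reply: a LightningModule tracks the device it was moved to, but the plain nn.Module it wraps does not, so an explicit device choice is still needed for manual inference outside a Trainer. The TinyTask class is a made-up example, not part of torchgeo:

```python
# Illustrative only: LightningModule exposes .device, plain nn.Modules do not.
import pytorch_lightning as pl
import torch


class TinyTask(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.model = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


task = TinyTask()
if torch.cuda.is_available():
    task = task.to("cuda")
print(task.device)  # the LightningModule knows where it lives
print(hasattr(task.model, "device"))  # False: the wrapped nn.Module does not
```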

args = parser.parse_args()
main(args.config_dir, args.predict_on, args.output_dir, args.device)
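For reference, the script can also be driven programmatically by importing main from predict.py; the paths below are placeholders:

```python
# Calling the inference entry point directly (paths are placeholders).
from predict import main

main(
    config_dir="output/my_experiment",  # must contain experiment_config.yaml and last.ckpt
    predict_on="data/unlabelled_tiles",  # directory or dataset split to predict on
    output_dir="predictions",
    device="cuda",
)
```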
3 changes: 2 additions & 1 deletion tests/datamodules/test_inria.py
@@ -9,11 +9,12 @@
from torchgeo.datamodules import InriaAerialImageLabelingDataModule

TEST_DATA_DIR = os.path.join("tests", "data", "inria")
PREDICT_DATA_DIR = os.path.join(TEST_DATA_DIR, "AerialImageDataset/test/images")


class TestInriaAerialImageLabelingDataModule:
@pytest.fixture(
-        params=zip([0.2, 0.2, 0.0], [0.2, 0.0, 0.0], ["test", "test", "test"])
+        params=zip([0.2, 0.2, 0.0], [0.2, 0.0, 0.0], ["test", "test", PREDICT_DATA_DIR])
)
def datamodule(self, request: SubRequest) -> InriaAerialImageLabelingDataModule:
val_split_pct, test_split_pct, predict_on = request.param
28 changes: 28 additions & 0 deletions tests/datasets/test_utils.py
@@ -16,12 +16,14 @@
import numpy as np
import pytest
import torch
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS

import torchgeo.datasets.utils
from torchgeo.datasets.utils import (
BoundingBox,
PredictDataset,
concat_samples,
disambiguate_timestamp,
download_and_extract_archive,
@@ -582,3 +584,29 @@ def test_percentile_normalization() -> None:
img = percentile_normalization(img, 2, 98)
assert img.min() == 0
assert img.max() == 1


class TestPredictDataset:
@pytest.fixture(
params=zip(
[None, torch.nn.Identity(), None], # type: ignore[no-untyped-call]
[(2, 2), (8, 8), (16, 16)],
)
)
def dataset(self, request: SubRequest) -> PredictDataset:
root = os.path.join(
"tests", "data", "inria", "AerialImageDataset", "test", "images"
)
transforms, patch_size = request.param
return PredictDataset(root, patch_size=patch_size, transforms=transforms)

def test_getitem(self, dataset: PredictDataset) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert x["image"].ndim == 5
assert len(x["original_shape"]) == len(x["patch_shape"]) == 2
assert len(x["padding"]) == 4

def test_len(self, dataset: PredictDataset) -> None:
assert len(dataset) == 5
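Based on the fixture above, a rough usage sketch of the new PredictDataset follows. Since the class is introduced by this PR, treat the exact shapes as indicative rather than guaranteed:

```python
# Rough usage sketch of PredictDataset as exercised by the test above.
import os

from torchgeo.datasets import PredictDataset

root = os.path.join("tests", "data", "inria", "AerialImageDataset", "test", "images")
ds = PredictDataset(root, patch_size=(16, 16), transforms=None)

sample = ds[0]
print(sample["image"].shape)  # 5-dimensional stack of patches for one tile
print(sample["original_shape"], sample["patch_shape"], sample["padding"])
print(len(ds))  # number of image tiles found under root
```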
16 changes: 11 additions & 5 deletions torchgeo/datamodules/inria.py
@@ -3,6 +3,7 @@

"""InriaAerialImageLabeling datamodule."""

import os
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import kornia.augmentation as K
@@ -14,7 +15,7 @@
from torch.utils.data import DataLoader, Dataset
from torch.utils.data._utils.collate import default_collate

-from ..datasets import InriaAerialImageLabeling
+from ..datasets import InriaAerialImageLabeling, PredictDataset
from ..samplers.utils import _to_tuple
from .utils import dataset_split

@@ -167,10 +168,15 @@ def setup(self, stage: Optional[str] = None) -> None:
self.val_dataset = train_dataset
self.test_dataset = train_dataset

-        assert self.predict_on == "test"
-        self.predict_dataset = InriaAerialImageLabeling(
-            self.root_dir, self.predict_on, transforms=test_transforms
-        )
+        if os.path.isdir(self.predict_on):
+            self.predict_dataset = PredictDataset(
+                self.predict_on, patch_size=self.patch_size, transforms=self.preprocess
+            )
+        else:
+            assert self.predict_on == "test"
+            self.predict_dataset = InriaAerialImageLabeling(  # type: ignore[assignment]
+                self.root_dir, self.predict_on, transforms=test_transforms
+            )

def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training."""
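The setup() change above means predict_on can now be either the literal "test" split or a path to a directory of images. A hedged usage sketch, with placeholder paths and the remaining keyword arguments left at their defaults:

```python
# Illustrative: predicting on an arbitrary directory of GeoTIFFs with the
# updated Inria datamodule. Paths are placeholders.
from torchgeo.datamodules import InriaAerialImageLabelingDataModule

dm = InriaAerialImageLabelingDataModule(
    root_dir="data/inria",               # labelled Inria data for train/val/test
    predict_on="data/unlabelled_tiles",  # any directory triggers PredictDataset
)
dm.setup()
loader = dm.predict_dataloader()
```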
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
@@ -97,6 +97,7 @@
from .usavars import USAVars
from .utils import (
BoundingBox,
PredictDataset,
concat_samples,
merge_samples,
stack_samples,
@@ -202,6 +203,7 @@
"VisionDataset",
# Utilities
"BoundingBox",
"PredictDataset",
"concat_samples",
"merge_samples",
"stack_samples",