-
Notifications
You must be signed in to change notification settings - Fork 388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add prediction utilities #560
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
"""torchgeo model inference script.""" | ||
|
||
import argparse | ||
import os | ||
from typing import Dict, Tuple, Type, cast | ||
|
||
import pytorch_lightning as pl | ||
import rasterio as rio | ||
import torch | ||
from kornia.contrib import CombineTensorPatches | ||
from omegaconf import OmegaConf | ||
|
||
from torchgeo.datamodules import ( | ||
BigEarthNetDataModule, | ||
ChesapeakeCVPRDataModule, | ||
COWCCountingDataModule, | ||
CycloneDataModule, | ||
ETCI2021DataModule, | ||
EuroSATDataModule, | ||
InriaAerialImageLabelingDataModule, | ||
LandCoverAIDataModule, | ||
NAIPChesapeakeDataModule, | ||
OSCDDataModule, | ||
RESISC45DataModule, | ||
SEN12MSDataModule, | ||
So2SatDataModule, | ||
UCMercedDataModule, | ||
) | ||
from torchgeo.trainers import ( | ||
BYOLTask, | ||
ClassificationTask, | ||
MultiLabelClassificationTask, | ||
RegressionTask, | ||
SemanticSegmentationTask, | ||
) | ||
|
||
# Maps the ``experiment.task`` name found in the config to the
# (LightningModule task, LightningDataModule) pair used to run inference.
TASK_TO_MODULES_MAPPING: Dict[
    str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]]
] = {
    "bigearthnet": (MultiLabelClassificationTask, BigEarthNetDataModule),
    "byol": (BYOLTask, ChesapeakeCVPRDataModule),
    "chesapeake_cvpr": (SemanticSegmentationTask, ChesapeakeCVPRDataModule),
    "cowc_counting": (RegressionTask, COWCCountingDataModule),
    "cyclone": (RegressionTask, CycloneDataModule),
    "eurosat": (ClassificationTask, EuroSATDataModule),
    "etci2021": (SemanticSegmentationTask, ETCI2021DataModule),
    "inria": (SemanticSegmentationTask, InriaAerialImageLabelingDataModule),
    "landcoverai": (SemanticSegmentationTask, LandCoverAIDataModule),
    "naipchesapeake": (SemanticSegmentationTask, NAIPChesapeakeDataModule),
    "oscd": (SemanticSegmentationTask, OSCDDataModule),
    "resisc45": (ClassificationTask, RESISC45DataModule),
    "sen12ms": (SemanticSegmentationTask, SEN12MSDataModule),
    "so2sat": (ClassificationTask, So2SatDataModule),
    "ucmerced": (ClassificationTask, UCMercedDataModule),
}
|
||
|
||
def write_mask(mask: torch.Tensor, output_dir: str, input_filename: str) -> None:
    """Write mask to specified output directory with same filename as input raster.

    Args:
        mask (torch.Tensor): mask tensor
        output_dir (str): output directory
        input_filename (str): path to input raster
    """
    output_path = os.path.join(output_dir, os.path.basename(input_filename))
    with rio.open(input_filename) as src:
        # Reuse the input raster's georeferencing but store a single uint8 band.
        profile = src.profile
        profile["count"] = 1
        profile["dtype"] = "uint8"
        # argmax produces an int64 tensor; cast the array to match the uint8
        # profile or rasterio will reject it at write time.
        arr = mask.cpu().numpy().astype("uint8")
        with rio.open(output_path, "w", **profile) as ds:
            ds.write(arr)
|
||
|
||
def main(config_dir: str, predict_on: str, output_dir: str, device: str) -> None:
    """Main inference loop.

    Args:
        config_dir (str): Path to config-dir to load config and ckpt
        predict_on (str): Directory/Dataset to run inference on
        output_dir (str): Path to output_directory to save predicted masks
        device (str): Choice of device. Must be in [cuda, cpu]

    Raises:
        ValueError: Raised if task name is not in TASK_TO_MODULES_MAPPING
        FileExistsError: Raised if specified output directory contains
            files and overwrite=False.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Load checkpoint and config
    # NOTE(review): the checkpoint filename could become a parameter; the
    # config filename ("experiment_config.yaml") is intentionally hard coded.
    conf = OmegaConf.load(os.path.join(config_dir, "experiment_config.yaml"))
    ckpt = os.path.join(config_dir, "last.ckpt")

    # Fail fast: refuse a non-empty output directory *before* spending time
    # loading the checkpoint and datamodule.
    if len(os.listdir(output_dir)) > 0:
        if conf.program.overwrite:
            print(
                f"WARNING! The output directory, {output_dir}, already exists, "
                + "we will overwrite data in it!"
            )
        else:
            raise FileExistsError(
                f"The predictions directory, {output_dir}, already exists and isn't "
                + "empty. We don't want to overwrite any existing results, exiting..."
            )

    # Load model
    task_name = conf.experiment.task
    datamodule: pl.LightningDataModule
    task: pl.LightningModule
    if task_name not in TASK_TO_MODULES_MAPPING:
        raise ValueError(
            f"experiment.task={task_name} is not recognized as a valid task"
        )
    task_class, datamodule_class = TASK_TO_MODULES_MAPPING[task_name]
    task = task_class.load_from_checkpoint(ckpt)
    task = task.to(device)
    task.eval()

    # Load datamodule and dataloader
    conf.experiment.datamodule["predict_on"] = predict_on
    datamodule = datamodule_class(**conf.experiment.datamodule)
    datamodule.setup()
    dataloader = datamodule.predict_dataloader()

    # Hoisted out of the batch loop: defining this closure per iteration
    # re-created the function object on every batch for no benefit.
    def tensor_to_int(tensor_tuple: Tuple[torch.Tensor, ...]) -> Tuple[int, ...]:
        """Convert tuple of scalar tensors to tuple of ints."""
        return tuple(int(i.item()) for i in tensor_tuple)

    for i, batch in enumerate(dataloader):
        x = batch["image"].to(device)  # (N, B, C, H, W)
        assert len(x.shape) in {4, 5}
        if len(x.shape) == 5:
            # Patched input: predict each tile, then stitch the per-tile masks
            # back to the original raster extent with kornia.
            # NOTE(review): casts shouldn't be required if the batch dict were
            # typed precisely; kept until the datamodule types are tightened.
            masks = []
            original_shape = cast(
                Tuple[int, int], tensor_to_int(batch["original_shape"])
            )
            patch_shape = cast(Tuple[int, int], tensor_to_int(batch["patch_shape"]))
            padding = cast(Tuple[int, int], tensor_to_int(batch["padding"]))
            patch_combine = CombineTensorPatches(
                original_size=original_shape, window_size=patch_shape, unpadding=padding
            )

            for tile in x:
                mask = task(tile)
                mask = mask.argmax(dim=1)
                masks.append(mask)

            masks_arr = torch.stack(masks, dim=0)
            masks_arr = masks_arr.unsqueeze(0)
            masks_combined = patch_combine(masks_arr)[0]
            filename = datamodule.predict_dataset.files[i]["image"]
            write_mask(masks_combined, output_dir, filename)
        else:
            mask = task(x)
            mask = mask.argmax(dim=1)
            filename = datamodule.predict_dataset.files[i]["image"]
            write_mask(mask, output_dir, filename)
|
||
|
||
if __name__ == "__main__":
    # Taken from https://github.com/pangeo-data/cog-best-practices
    # NOTE(review): could move to a shared torchgeo/common.py to avoid
    # duplicating this dict across scripts.
    _rasterio_best_practices = {
        "GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR",
        "AWS_NO_SIGN_REQUEST": "YES",
        "GDAL_MAX_RAW_BLOCK_CACHE_SIZE": "200000000",
        "GDAL_SWATH_SIZE": "200000000",
        "VSI_CURL_CACHE_SIZE": "200000000",
    }
    os.environ.update(_rasterio_best_practices)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config-dir",
        type=str,
        required=True,
        help="Path to config-dir to load config and ckpt",
    )

    # Accept the dashed spelling for consistency with the other flags;
    # "--predict_on" is kept as an alias for backward compatibility.
    parser.add_argument(
        "--predict_on",
        "--predict-on",
        type=str,
        required=True,
        help="Directory/Dataset to run inference on",
    )

    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Path to output_directory to save predicted mask geotiffs",
    )

    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])

    args = parser.parse_args()
    main(args.config_dir, args.predict_on, args.output_dir, args.device)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now that this is in two different scripts, we should move it somewhere where both scripts can find it to avoid code duplication. How about a
torchgeo/common.py
file? I don't want to put it intorchgeo/__init__.py
because this will be sourced on every import.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Like the idea of
torchgeo/common.py
. I can add that in a follow up PR.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 to this