From c21b84757d7f52a081a47066efd59a9166eb8394 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 13 Apr 2023 19:06:38 +0000 Subject: [PATCH 01/24] add dense regression task --- torchgeo/trainers/__init__.py | 3 +- torchgeo/trainers/regression.py | 135 +++++++++++++++++++++++++++++++- 2 files changed, 135 insertions(+), 3 deletions(-) diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py index 6240b53f681..d62a7cde076 100644 --- a/torchgeo/trainers/__init__.py +++ b/torchgeo/trainers/__init__.py @@ -6,12 +6,13 @@ from .byol import BYOLTask from .classification import ClassificationTask, MultiLabelClassificationTask from .detection import ObjectDetectionTask -from .regression import RegressionTask +from .regression import DenseRegressionTask, RegressionTask from .segmentation import SemanticSegmentationTask __all__ = ( "BYOLTask", "ClassificationTask", + "DenseRegressionTask", "MultiLabelClassificationTask", "ObjectDetectionTask", "RegressionTask", diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 45daf0788b7..30d46d6fda6 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -7,8 +7,10 @@ from typing import Any, cast import matplotlib.pyplot as plt +import segmentation_models_pytorch as smp import timm import torch +import torch.nn as nn import torch.nn.functional as F from lightning.pytorch import LightningModule from torch import Tensor @@ -17,7 +19,7 @@ from torchvision.models._api import WeightsEnum from ..datasets import unbind_samples -from ..models import get_weight +from ..models import FCN, get_weight from . import utils @@ -80,7 +82,11 @@ def __init__(self, **kwargs: Any) -> None: self.config_task() self.train_metrics = MetricCollection( - {"RMSE": MeanSquaredError(squared=False), "MAE": MeanAbsoluteError()}, + { + "RMSE": MeanSquaredError(squared=False), + "MSE": MeanSquaredError(squared=True), + "MAE": MeanAbsoluteError(), + }, prefix="train_", ) self.val_metrics = self.train_metrics.clone(prefix="val_") @@ -219,3 +225,128 @@ def configure_optimizers(self) -> dict[str, Any]: "monitor": "val_loss", }, } + + +class DenseRegressionTask(RegressionTask): + """LightningModule for dense regression of images. + + Supports `Segmentation Models Pytorch + `_ + as an architecture choice in combination with any of these + `TIMM backbones `_. + + .. versionadded:: 0.5 + """ + + def config_task(self) -> None: + """Configures the task based on kwargs parameters passed to the constructor.""" + if self.hyperparams["model"] == "unet": + self.model = smp.Unet( + encoder_name=self.hyperparams["backbone"], + encoder_weights=self.hyperparams["weights"], + in_channels=self.hyperparams["in_channels"], + classes=1, + ) + elif self.hyperparams["model"] == "deeplabv3+": + self.model = smp.DeepLabV3Plus( + encoder_name=self.hyperparams["backbone"], + encoder_weights=self.hyperparams["weights"], + in_channels=self.hyperparams["in_channels"], + classes=1, + ) + elif self.hyperparams["model"] == "fcn": + self.model = FCN( + in_channels=self.hyperparams["in_channels"], + classes=1, + num_filters=self.hyperparams["num_filters"], + ) + else: + raise ValueError( + f"Model type '{self.hyperparams['model']}' is not valid. " + f"Currently, only supports 'unet', 'deeplabv3+' and 'fcn'." 
+ ) + + if self.hyperparams["loss"] == "mse": + self.loss = nn.MSELoss() + elif self.hyperparams["loss"] == "mae": + self.loss = nn.L1Loss() + else: + raise ValueError( + f"Loss type '{self.hyperparams['loss']}' is not valid. " + f"Currently, supports 'mse' or 'mae' loss." + ) + + def training_step(self, *args: Any, **kwargs: Any) -> Tensor: + """Compute and return the training loss. + + Args: + batch: the output of your DataLoader + + Returns: + training loss + """ + batch = args[0] + x = batch["image"] + y = batch["mask"] + y_hat = self(x) + + loss = self.loss(y_hat, y) + + self.log("train_loss", loss) # logging to TensorBoard + self.train_metrics(y_hat, y) + + return loss + + def validation_step(self, *args: Any, **kwargs: Any) -> None: + """Compute validation loss and log example predictions. + + Args: + batch: the output of your DataLoader + batch_idx: the index of this batch + """ + batch = args[0] + batch_idx = args[1] + x = batch["image"] + y = batch["mask"] + y_hat = self(x) + + loss = self.loss(y_hat, y) + self.log("val_loss", loss) + self.val_metrics(y_hat, y) + + if ( + batch_idx < 10 + and hasattr(self.trainer, "datamodule") + and self.logger + and hasattr(self.logger, "experiment") + and hasattr(self.logger.experiment, "add_figure") + ): + try: + datamodule = self.trainer.datamodule + batch["prediction"] = y_hat + for key in ["image", "mask", "prediction"]: + batch[key] = batch[key].cpu() + sample = unbind_samples(batch)[0] + fig = datamodule.plot(sample) + summary_writer = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + plt.close() + except ValueError: + pass + + def test_step(self, *args: Any, **kwargs: Any) -> None: + """Compute test loss. + + Args: + batch: the output of your DataLoader + """ + batch = args[0] + x = batch["image"] + y = batch["mask"] + y_hat = self(x) + + loss = self.loss(y_hat, y) + self.log("test_loss", loss) + self.test_metrics(y_hat, y) From 6a2c519f7af05e72de0b1ce01822a4e5fc49d302 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 13 Apr 2023 20:12:42 +0000 Subject: [PATCH 02/24] fix mypy --- torchgeo/trainers/regression.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 30d46d6fda6..205465e3d2a 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -266,6 +266,7 @@ def config_task(self) -> None: f"Currently, only supports 'unet', 'deeplabv3+' and 'fcn'." ) + self.loss: nn.Module if self.hyperparams["loss"] == "mse": self.loss = nn.MSELoss() elif self.hyperparams["loss"] == "mae": @@ -295,7 +296,7 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor: self.log("train_loss", loss) # logging to TensorBoard self.train_metrics(y_hat, y) - return loss + return cast(Tensor, loss) def validation_step(self, *args: Any, **kwargs: Any) -> None: """Compute validation loss and log example predictions. 
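For reviewers of this series: a minimal sketch of how the task introduced in the two patches above is expected to be constructed, inferred from the hyperparameter keys read in config_task(). The concrete values (backbone, channel count, learning rate) are illustrative placeholders rather than values taken from these patches, and the class is renamed to PixelwiseRegressionTask later in the series.

    from torchgeo.trainers import DenseRegressionTask

    # Keyword names mirror the hyperparameters read in config_task(); the
    # values below are placeholder assumptions for illustration only.
    task = DenseRegressionTask(
        model="unet",            # "unet", "deeplabv3+", or "fcn"
        backbone="resnet18",     # any SMP-compatible timm encoder
        weights=None,            # encoder weights passed through to SMP
        in_channels=3,
        num_filters=64,          # only read when model="fcn"
        loss="mse",              # "mse" or "mae"
        learning_rate=1e-3,
        learning_rate_schedule_patience=6,
    )
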
From 20dd757ca592741887e0fcd38d7bf5d5aff9a929 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 21 Apr 2023 21:00:38 +0000 Subject: [PATCH 03/24] refactor and rename to reduce code duplication --- torchgeo/trainers/regression.py | 130 ++++++++------------------------ 1 file changed, 31 insertions(+), 99 deletions(-) diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 205465e3d2a..944993ab32d 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -11,7 +11,6 @@ import timm import torch import torch.nn as nn -import torch.nn.functional as F from lightning.pytorch import LightningModule from torch import Tensor from torch.optim.lr_scheduler import ReduceLROnPlateau @@ -37,8 +36,10 @@ class RegressionTask(LightningModule): # type: ignore[misc] print(timm.list_models()) """ - def config_task(self) -> None: - """Configures the task based on kwargs parameters.""" + target_key: str = "label" + + def config_model(self) -> None: + """Configures the model based on kwargs parameters.""" # Create model weights = self.hyperparams["weights"] self.model = timm.create_model( @@ -58,6 +59,21 @@ def config_task(self) -> None: state_dict = get_weight(weights).get_state_dict(progress=True) self.model = utils.load_state_dict(self.model, state_dict) + def config_task(self) -> None: + """Configures the task based on kwargs parameters.""" + self.config_model() + + self.loss: nn.Module + if self.hyperparams["loss"] == "mse": + self.loss = nn.MSELoss() + elif self.hyperparams["loss"] == "mae": + self.loss = nn.L1Loss() + else: + raise ValueError( + f"Loss type '{self.hyperparams['loss']}' is not valid. " + f"Currently, supports 'mse' or 'mae' loss." + ) + def __init__(self, **kwargs: Any) -> None: """Initialize a new LightningModule for training simple regression models. @@ -114,10 +130,10 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor: """ batch = args[0] x = batch["image"] - y = batch["label"].view(-1, 1) + y = batch[self.target_key] y_hat = self(x) - loss = F.mse_loss(y_hat, y) + loss = self.loss(y_hat, y) self.log("train_loss", loss) # logging to TensorBoard self.train_metrics(y_hat, y) @@ -139,10 +155,10 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: batch = args[0] batch_idx = args[1] x = batch["image"] - y = batch["label"].view(-1, 1) + y = batch[self.target_key] y_hat = self(x) - loss = F.mse_loss(y_hat, y) + loss = self.loss(y_hat, y) self.log("val_loss", loss) self.val_metrics(y_hat, y) @@ -156,7 +172,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: try: datamodule = self.trainer.datamodule batch["prediction"] = y_hat - for key in ["image", "label", "prediction"]: + for key in ["image", self.target_key, "prediction"]: batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] fig = datamodule.plot(sample) @@ -181,10 +197,10 @@ def test_step(self, *args: Any, **kwargs: Any) -> None: """ batch = args[0] x = batch["image"] - y = batch["label"].view(-1, 1) + y = batch[self.target_key] y_hat = self(x) - loss = F.mse_loss(y_hat, y) + loss = self.loss(y_hat, y) self.log("test_loss", loss) self.test_metrics(y_hat, y) @@ -227,7 +243,7 @@ def configure_optimizers(self) -> dict[str, Any]: } -class DenseRegressionTask(RegressionTask): +class PixelwiseRegressionTask(RegressionTask): """LightningModule for dense regression of images. Supports `Segmentation Models Pytorch @@ -238,8 +254,10 @@ class DenseRegressionTask(RegressionTask): .. 
versionadded:: 0.5 """ - def config_task(self) -> None: - """Configures the task based on kwargs parameters passed to the constructor.""" + target_key: str = "mask" + + def config_model(self) -> None: + """Configures the model based on kwargs parameters.""" if self.hyperparams["model"] == "unet": self.model = smp.Unet( encoder_name=self.hyperparams["backbone"], @@ -265,89 +283,3 @@ def config_task(self) -> None: f"Model type '{self.hyperparams['model']}' is not valid. " f"Currently, only supports 'unet', 'deeplabv3+' and 'fcn'." ) - - self.loss: nn.Module - if self.hyperparams["loss"] == "mse": - self.loss = nn.MSELoss() - elif self.hyperparams["loss"] == "mae": - self.loss = nn.L1Loss() - else: - raise ValueError( - f"Loss type '{self.hyperparams['loss']}' is not valid. " - f"Currently, supports 'mse' or 'mae' loss." - ) - - def training_step(self, *args: Any, **kwargs: Any) -> Tensor: - """Compute and return the training loss. - - Args: - batch: the output of your DataLoader - - Returns: - training loss - """ - batch = args[0] - x = batch["image"] - y = batch["mask"] - y_hat = self(x) - - loss = self.loss(y_hat, y) - - self.log("train_loss", loss) # logging to TensorBoard - self.train_metrics(y_hat, y) - - return cast(Tensor, loss) - - def validation_step(self, *args: Any, **kwargs: Any) -> None: - """Compute validation loss and log example predictions. - - Args: - batch: the output of your DataLoader - batch_idx: the index of this batch - """ - batch = args[0] - batch_idx = args[1] - x = batch["image"] - y = batch["mask"] - y_hat = self(x) - - loss = self.loss(y_hat, y) - self.log("val_loss", loss) - self.val_metrics(y_hat, y) - - if ( - batch_idx < 10 - and hasattr(self.trainer, "datamodule") - and self.logger - and hasattr(self.logger, "experiment") - and hasattr(self.logger.experiment, "add_figure") - ): - try: - datamodule = self.trainer.datamodule - batch["prediction"] = y_hat - for key in ["image", "mask", "prediction"]: - batch[key] = batch[key].cpu() - sample = unbind_samples(batch)[0] - fig = datamodule.plot(sample) - summary_writer = self.logger.experiment - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - plt.close() - except ValueError: - pass - - def test_step(self, *args: Any, **kwargs: Any) -> None: - """Compute test loss. 
- - Args: - batch: the output of your DataLoader - """ - batch = args[0] - x = batch["image"] - y = batch["mask"] - y_hat = self(x) - - loss = self.loss(y_hat, y) - self.log("test_loss", loss) - self.test_metrics(y_hat, y) From 0942fd2e1104860c0574af2726b39d2b681ad1f8 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 21 Apr 2023 21:00:50 +0000 Subject: [PATCH 04/24] rename --- torchgeo/trainers/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py index d62a7cde076..e1db43fd3e5 100644 --- a/torchgeo/trainers/__init__.py +++ b/torchgeo/trainers/__init__.py @@ -6,15 +6,15 @@ from .byol import BYOLTask from .classification import ClassificationTask, MultiLabelClassificationTask from .detection import ObjectDetectionTask -from .regression import DenseRegressionTask, RegressionTask +from .regression import PixelwiseRegressionTask, RegressionTask from .segmentation import SemanticSegmentationTask __all__ = ( "BYOLTask", "ClassificationTask", - "DenseRegressionTask", "MultiLabelClassificationTask", "ObjectDetectionTask", + "PixelwiseRegressionTask", "RegressionTask", "SemanticSegmentationTask", ) From 39700a8d17b7b7e25c2b6d7d6b171219818b02c2 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 21 Apr 2023 21:06:26 +0000 Subject: [PATCH 05/24] add loss to test configs --- tests/conf/cowc_counting.yaml | 1 + tests/conf/cyclone.yaml | 1 + tests/conf/skippd.yaml | 1 + tests/conf/sustainbench_crop_yield.yaml | 1 + 4 files changed, 4 insertions(+) diff --git a/tests/conf/cowc_counting.yaml b/tests/conf/cowc_counting.yaml index fc3218e8fef..041748de953 100644 --- a/tests/conf/cowc_counting.yaml +++ b/tests/conf/cowc_counting.yaml @@ -7,6 +7,7 @@ experiment: in_channels: 3 learning_rate: 1e-3 learning_rate_schedule_patience: 2 + loss: "mse" datamodule: root: "tests/data/cowc_counting" download: true diff --git a/tests/conf/cyclone.yaml b/tests/conf/cyclone.yaml index b3323d28999..29a601b3f3c 100644 --- a/tests/conf/cyclone.yaml +++ b/tests/conf/cyclone.yaml @@ -7,6 +7,7 @@ experiment: in_channels: 3 learning_rate: 1e-3 learning_rate_schedule_patience: 2 + loss: "mse" datamodule: root: "tests/data/cyclone" download: true diff --git a/tests/conf/skippd.yaml b/tests/conf/skippd.yaml index 8f1c1cb655f..2c50192665f 100644 --- a/tests/conf/skippd.yaml +++ b/tests/conf/skippd.yaml @@ -7,6 +7,7 @@ experiment: in_channels: 3 learning_rate: 1e-3 learning_rate_schedule_patience: 2 + loss: "mse" datamodule: root: "tests/data/skippd" download: true diff --git a/tests/conf/sustainbench_crop_yield.yaml b/tests/conf/sustainbench_crop_yield.yaml index 60903ea7d4c..c9cf58a53d8 100644 --- a/tests/conf/sustainbench_crop_yield.yaml +++ b/tests/conf/sustainbench_crop_yield.yaml @@ -7,6 +7,7 @@ experiment: in_channels: 9 learning_rate: 1e-3 learning_rate_schedule_patience: 2 + loss: "mse" datamodule: root: "tests/data/sustainbench_crop_yield" download: true From d6fefb6d8f82e9bfe99aa50e7df9a5a209945b33 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 21 Apr 2023 21:06:47 +0000 Subject: [PATCH 06/24] rename dense to pixelwise --- torchgeo/trainers/regression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 944993ab32d..70cfe1b7797 100644 --- a/torchgeo/trainers/regression.py +++ 
b/torchgeo/trainers/regression.py @@ -244,7 +244,7 @@ def configure_optimizers(self) -> dict[str, Any]: class PixelwiseRegressionTask(RegressionTask): - """LightningModule for dense regression of images. + """LightningModule for pixelwise regression of images. Supports `Segmentation Models Pytorch `_ From 12180b9529ffd2ae464ba8caaa02ddcf47e4af23 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 21 Apr 2023 21:14:15 +0000 Subject: [PATCH 07/24] if batch_size=1 then targets will have dims=1 but need dims=2 for loss --- tests/trainers/test_regression.py | 7 +++++++ torchgeo/trainers/regression.py | 12 ++++++++++++ 2 files changed, 19 insertions(+) diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index f4210a7cfa9..e495054973a 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -103,6 +103,7 @@ def model_kwargs(self) -> dict[str, Any]: "weights": None, "num_outputs": 1, "in_channels": 3, + "loss": "mse", } @pytest.fixture( @@ -198,3 +199,9 @@ def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None max_epochs=1, ) trainer.predict(model=model, datamodule=datamodule) + + def test_invalid_loss(self, model_kwargs: dict[str, Any]) -> None: + model_kwargs["loss"] = "invalid_loss" + match = "Loss type 'invalid_loss' is not valid." + with pytest.raises(ValueError, match=match): + RegressionTask(**model_kwargs) diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 70cfe1b7797..1cb00c04a18 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -131,6 +131,10 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor: batch = args[0] x = batch["image"] y = batch[self.target_key] + + if y.ndim == 1: + y = y.unsqueeze(dim=1) + y_hat = self(x) loss = self.loss(y_hat, y) @@ -156,6 +160,10 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: batch_idx = args[1] x = batch["image"] y = batch[self.target_key] + + if y.ndim == 1: + y = y.unsqueeze(dim=1) + y_hat = self(x) loss = self.loss(y_hat, y) @@ -198,6 +206,10 @@ def test_step(self, *args: Any, **kwargs: Any) -> None: batch = args[0] x = batch["image"] y = batch[self.target_key] + + if y.ndim == 1: + y = y.unsqueeze(dim=1) + y_hat = self(x) loss = self.loss(y_hat, y) From 0b81934754bcce94ed24cd25fac331d71cbeee75 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 21 Apr 2023 21:57:44 +0000 Subject: [PATCH 08/24] add tests' --- tests/trainers/test_regression.py | 85 ++++++++++++++++++++++++++++++- torchgeo/trainers/regression.py | 31 ++++++----- 2 files changed, 99 insertions(+), 17 deletions(-) diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index e495054973a..1b31d4108ca 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -6,17 +6,21 @@ from typing import Any, cast import pytest +import segmentation_models_pytorch as smp import timm import torch +import torch.nn as nn import torchvision from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch from lightning.pytorch import LightningDataModule, Trainer from omegaconf import OmegaConf +from torch.nn.modules import Module from torchvision.models._api import WeightsEnum from torchgeo.datamodules import ( COWCCountingDataModule, + InriaAerialImageLabelingDataModule, MisconfigurationException, SKIPPDDataModule, SustainBenchCropYieldDataModule, 
@@ -24,11 +28,22 @@ ) from torchgeo.datasets import TropicalCyclone from torchgeo.models import get_model_weights, list_models -from torchgeo.trainers import RegressionTask +from torchgeo.trainers import PixelwiseRegressionTask, RegressionTask from .test_classification import ClassificationTestModel +class PixelwiseRegressionTestModel(Module): + def __init__(self, in_channels: int = 3, classes: int = 1, **kwargs: Any) -> None: + super().__init__() + self.conv1 = nn.Conv2d( + in_channels=in_channels, out_channels=classes, kernel_size=1, padding=0 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return cast(torch.Tensor, self.conv1(x)) + + class RegressionTestModel(ClassificationTestModel): def __init__(self, in_chans: int = 3, num_classes: int = 1, **kwargs: Any) -> None: super().__init__(in_chans=in_chans, num_classes=num_classes) @@ -48,6 +63,10 @@ def plot(*args: Any, **kwargs: Any) -> None: raise ValueError +def create_model(**kwargs: Any) -> Module: + return PixelwiseRegressionTestModel(**kwargs) + + class TestRegressionTask: @pytest.mark.parametrize( "name,classname", @@ -205,3 +224,67 @@ def test_invalid_loss(self, model_kwargs: dict[str, Any]) -> None: match = "Loss type 'invalid_loss' is not valid." with pytest.raises(ValueError, match=match): RegressionTask(**model_kwargs) + + +class TestPixelwiseRegressionTask: + @pytest.mark.parametrize( + "name,classname,batch_size,loss", + [ + ("inria", InriaAerialImageLabelingDataModule, 1, "mse"), + ("inria", InriaAerialImageLabelingDataModule, 2, "mae"), + ], + ) + def test_trainer( + self, + monkeypatch: MonkeyPatch, + name: str, + classname: type[LightningDataModule], + batch_size: int, + loss: str, + fast_dev_run: bool, + ) -> None: + conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) + conf_dict = OmegaConf.to_object(conf.experiment) + conf_dict = cast(dict[str, dict[str, Any]], conf_dict) + + # Instantiate datamodule + datamodule_kwargs = conf_dict["datamodule"] + datamodule_kwargs["batch_size"] = batch_size + datamodule = classname(**datamodule_kwargs) + + # Instantiate model + monkeypatch.setattr(smp, "Unet", create_model) + monkeypatch.setattr(smp, "DeepLabV3Plus", create_model) + model_kwargs = conf_dict["module"] + model_kwargs["loss"] = loss + model = PixelwiseRegressionTask(**model_kwargs) + + model.model = PixelwiseRegressionTestModel(in_chans=model_kwargs["in_channels"]) + + # Instantiate trainer + trainer = Trainer( + accelerator="cpu", + fast_dev_run=fast_dev_run, + log_every_n_steps=1, + max_epochs=1, + ) + + trainer.fit(model=model, datamodule=datamodule) + try: + trainer.test(model=model, datamodule=datamodule) + except MisconfigurationException: + pass + try: + trainer.predict(model=model, datamodule=datamodule) + except MisconfigurationException: + pass + + @pytest.fixture + def model_kwargs(self) -> dict[str, Any]: + return { + "model": "resnet18", + "weights": None, + "num_outputs": 1, + "in_channels": 3, + "loss": "mse", + } diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 1cb00c04a18..23d1754d25e 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -131,16 +131,14 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor: batch = args[0] x = batch["image"] y = batch[self.target_key] - - if y.ndim == 1: - y = y.unsqueeze(dim=1) - y_hat = self(x) - loss = self.loss(y_hat, y) + if y_hat.ndim != y.ndim: + y = y.unsqueeze(dim=1) + loss = self.loss(y_hat, y.to(torch.float)) self.log("train_loss", loss) # logging to 
TensorBoard - self.train_metrics(y_hat, y) + self.train_metrics(y_hat, y.to(torch.float)) return loss @@ -160,15 +158,14 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: batch_idx = args[1] x = batch["image"] y = batch[self.target_key] + y_hat = self(x) - if y.ndim == 1: + if y_hat.ndim != y.ndim: y = y.unsqueeze(dim=1) - y_hat = self(x) - - loss = self.loss(y_hat, y) + loss = self.loss(y_hat, y.to(torch.float)) self.log("val_loss", loss) - self.val_metrics(y_hat, y) + self.val_metrics(y_hat, y.to(torch.float)) if ( batch_idx < 10 @@ -179,6 +176,9 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: ): try: datamodule = self.trainer.datamodule + if self.target_key == "mask": + y = y.squeeze(dim=1) + y_hat = y_hat.squeeze(dim=1) batch["prediction"] = y_hat for key in ["image", self.target_key, "prediction"]: batch[key] = batch[key].cpu() @@ -206,15 +206,14 @@ def test_step(self, *args: Any, **kwargs: Any) -> None: batch = args[0] x = batch["image"] y = batch[self.target_key] + y_hat = self(x) - if y.ndim == 1: + if y_hat.ndim != y.ndim: y = y.unsqueeze(dim=1) - y_hat = self(x) - - loss = self.loss(y_hat, y) + loss = self.loss(y_hat, y.to(torch.float)) self.log("test_loss", loss) - self.test_metrics(y_hat, y) + self.test_metrics(y_hat, y.to(torch.float)) def on_test_epoch_end(self) -> None: """Logs epoch level test metrics.""" From ba565e4090adb210cc470b2195f592bcbbe1026e Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 21 Apr 2023 22:47:52 +0000 Subject: [PATCH 09/24] fix mypy errors --- torchgeo/trainers/regression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 23d1754d25e..c20b556e565 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -136,7 +136,7 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor: if y_hat.ndim != y.ndim: y = y.unsqueeze(dim=1) - loss = self.loss(y_hat, y.to(torch.float)) + loss: Tensor = self.loss(y_hat, y.to(torch.float)) self.log("train_loss", loss) # logging to TensorBoard self.train_metrics(y_hat, y.to(torch.float)) From ffda89e01c3ee3a99cbc00fb33380bf7d52354d7 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 21 Apr 2023 22:58:44 +0000 Subject: [PATCH 10/24] get full coverage --- tests/trainers/test_regression.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 1b31d4108ca..75835bcac02 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -228,10 +228,11 @@ def test_invalid_loss(self, model_kwargs: dict[str, Any]) -> None: class TestPixelwiseRegressionTask: @pytest.mark.parametrize( - "name,classname,batch_size,loss", + "name,classname,batch_size,loss,model_type", [ - ("inria", InriaAerialImageLabelingDataModule, 1, "mse"), - ("inria", InriaAerialImageLabelingDataModule, 2, "mae"), + ("inria", InriaAerialImageLabelingDataModule, 1, "mse", "unet"), + ("inria", InriaAerialImageLabelingDataModule, 2, "mae", "deeplabv3+"), + ("inria", InriaAerialImageLabelingDataModule, 1, "mse", "fcn"), ], ) def test_trainer( @@ -241,6 +242,7 @@ def test_trainer( classname: type[LightningDataModule], batch_size: int, loss: str, + model_type: str, fast_dev_run: bool, ) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) @@ -257,6 +259,11 @@ def 
test_trainer( monkeypatch.setattr(smp, "DeepLabV3Plus", create_model) model_kwargs = conf_dict["module"] model_kwargs["loss"] = loss + model_kwargs["model"] = model_type + + if model_type == "fcn": + model_kwargs["num_filters"] = 2 + model = PixelwiseRegressionTask(**model_kwargs) model.model = PixelwiseRegressionTestModel(in_chans=model_kwargs["in_channels"]) From e05a3b563671f6e24f467a2e3e6965091e556d6b Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 21 Apr 2023 23:28:56 +0000 Subject: [PATCH 11/24] test for invalid model --- tests/trainers/test_regression.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 75835bcac02..22bc5abd3d4 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -286,6 +286,12 @@ def test_trainer( except MisconfigurationException: pass + def test_invalid_model(self, model_kwargs: dict[str, Any]) -> None: + model_kwargs["model"] = "invalid_model" + match = "Model type 'invalid_model' is not valid." + with pytest.raises(ValueError, match=match): + PixelwiseRegressionTask(**model_kwargs) + @pytest.fixture def model_kwargs(self) -> dict[str, Any]: return { From eff09179c33e94efe1b36e56307b8b7021bfaa64 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 13 Apr 2023 19:06:38 +0000 Subject: [PATCH 12/24] add dense regression task --- torchgeo/trainers/__init__.py | 3 +- torchgeo/trainers/regression.py | 135 +++++++++++++++++++++++++++++++- 2 files changed, 135 insertions(+), 3 deletions(-) diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py index 6240b53f681..d62a7cde076 100644 --- a/torchgeo/trainers/__init__.py +++ b/torchgeo/trainers/__init__.py @@ -6,12 +6,13 @@ from .byol import BYOLTask from .classification import ClassificationTask, MultiLabelClassificationTask from .detection import ObjectDetectionTask -from .regression import RegressionTask +from .regression import DenseRegressionTask, RegressionTask from .segmentation import SemanticSegmentationTask __all__ = ( "BYOLTask", "ClassificationTask", + "DenseRegressionTask", "MultiLabelClassificationTask", "ObjectDetectionTask", "RegressionTask", diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 45daf0788b7..30d46d6fda6 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -7,8 +7,10 @@ from typing import Any, cast import matplotlib.pyplot as plt +import segmentation_models_pytorch as smp import timm import torch +import torch.nn as nn import torch.nn.functional as F from lightning.pytorch import LightningModule from torch import Tensor @@ -17,7 +19,7 @@ from torchvision.models._api import WeightsEnum from ..datasets import unbind_samples -from ..models import get_weight +from ..models import FCN, get_weight from . import utils @@ -80,7 +82,11 @@ def __init__(self, **kwargs: Any) -> None: self.config_task() self.train_metrics = MetricCollection( - {"RMSE": MeanSquaredError(squared=False), "MAE": MeanAbsoluteError()}, + { + "RMSE": MeanSquaredError(squared=False), + "MSE": MeanSquaredError(squared=True), + "MAE": MeanAbsoluteError(), + }, prefix="train_", ) self.val_metrics = self.train_metrics.clone(prefix="val_") @@ -219,3 +225,128 @@ def configure_optimizers(self) -> dict[str, Any]: "monitor": "val_loss", }, } + + +class DenseRegressionTask(RegressionTask): + """LightningModule for dense regression of images. 
+ + Supports `Segmentation Models Pytorch + `_ + as an architecture choice in combination with any of these + `TIMM backbones `_. + + .. versionadded:: 0.5 + """ + + def config_task(self) -> None: + """Configures the task based on kwargs parameters passed to the constructor.""" + if self.hyperparams["model"] == "unet": + self.model = smp.Unet( + encoder_name=self.hyperparams["backbone"], + encoder_weights=self.hyperparams["weights"], + in_channels=self.hyperparams["in_channels"], + classes=1, + ) + elif self.hyperparams["model"] == "deeplabv3+": + self.model = smp.DeepLabV3Plus( + encoder_name=self.hyperparams["backbone"], + encoder_weights=self.hyperparams["weights"], + in_channels=self.hyperparams["in_channels"], + classes=1, + ) + elif self.hyperparams["model"] == "fcn": + self.model = FCN( + in_channels=self.hyperparams["in_channels"], + classes=1, + num_filters=self.hyperparams["num_filters"], + ) + else: + raise ValueError( + f"Model type '{self.hyperparams['model']}' is not valid. " + f"Currently, only supports 'unet', 'deeplabv3+' and 'fcn'." + ) + + if self.hyperparams["loss"] == "mse": + self.loss = nn.MSELoss() + elif self.hyperparams["loss"] == "mae": + self.loss = nn.L1Loss() + else: + raise ValueError( + f"Loss type '{self.hyperparams['loss']}' is not valid. " + f"Currently, supports 'mse' or 'mae' loss." + ) + + def training_step(self, *args: Any, **kwargs: Any) -> Tensor: + """Compute and return the training loss. + + Args: + batch: the output of your DataLoader + + Returns: + training loss + """ + batch = args[0] + x = batch["image"] + y = batch["mask"] + y_hat = self(x) + + loss = self.loss(y_hat, y) + + self.log("train_loss", loss) # logging to TensorBoard + self.train_metrics(y_hat, y) + + return loss + + def validation_step(self, *args: Any, **kwargs: Any) -> None: + """Compute validation loss and log example predictions. + + Args: + batch: the output of your DataLoader + batch_idx: the index of this batch + """ + batch = args[0] + batch_idx = args[1] + x = batch["image"] + y = batch["mask"] + y_hat = self(x) + + loss = self.loss(y_hat, y) + self.log("val_loss", loss) + self.val_metrics(y_hat, y) + + if ( + batch_idx < 10 + and hasattr(self.trainer, "datamodule") + and self.logger + and hasattr(self.logger, "experiment") + and hasattr(self.logger.experiment, "add_figure") + ): + try: + datamodule = self.trainer.datamodule + batch["prediction"] = y_hat + for key in ["image", "mask", "prediction"]: + batch[key] = batch[key].cpu() + sample = unbind_samples(batch)[0] + fig = datamodule.plot(sample) + summary_writer = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + plt.close() + except ValueError: + pass + + def test_step(self, *args: Any, **kwargs: Any) -> None: + """Compute test loss. 
+ + Args: + batch: the output of your DataLoader + """ + batch = args[0] + x = batch["image"] + y = batch["mask"] + y_hat = self(x) + + loss = self.loss(y_hat, y) + self.log("test_loss", loss) + self.test_metrics(y_hat, y) From 298d7bb6e7af22a7f591a1a53eee12e4cdc84172 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 13 Apr 2023 20:12:42 +0000 Subject: [PATCH 13/24] fix mypy --- torchgeo/trainers/regression.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 30d46d6fda6..205465e3d2a 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -266,6 +266,7 @@ def config_task(self) -> None: f"Currently, only supports 'unet', 'deeplabv3+' and 'fcn'." ) + self.loss: nn.Module if self.hyperparams["loss"] == "mse": self.loss = nn.MSELoss() elif self.hyperparams["loss"] == "mae": @@ -295,7 +296,7 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor: self.log("train_loss", loss) # logging to TensorBoard self.train_metrics(y_hat, y) - return loss + return cast(Tensor, loss) def validation_step(self, *args: Any, **kwargs: Any) -> None: """Compute validation loss and log example predictions. From 7cf0e2653e1e8bd916e28995955a01cce7723800 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 21 Apr 2023 21:00:38 +0000 Subject: [PATCH 14/24] refactor and rename to reduce code duplication --- torchgeo/trainers/regression.py | 130 ++++++++------------------------ 1 file changed, 31 insertions(+), 99 deletions(-) diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 205465e3d2a..944993ab32d 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -11,7 +11,6 @@ import timm import torch import torch.nn as nn -import torch.nn.functional as F from lightning.pytorch import LightningModule from torch import Tensor from torch.optim.lr_scheduler import ReduceLROnPlateau @@ -37,8 +36,10 @@ class RegressionTask(LightningModule): # type: ignore[misc] print(timm.list_models()) """ - def config_task(self) -> None: - """Configures the task based on kwargs parameters.""" + target_key: str = "label" + + def config_model(self) -> None: + """Configures the model based on kwargs parameters.""" # Create model weights = self.hyperparams["weights"] self.model = timm.create_model( @@ -58,6 +59,21 @@ def config_task(self) -> None: state_dict = get_weight(weights).get_state_dict(progress=True) self.model = utils.load_state_dict(self.model, state_dict) + def config_task(self) -> None: + """Configures the task based on kwargs parameters.""" + self.config_model() + + self.loss: nn.Module + if self.hyperparams["loss"] == "mse": + self.loss = nn.MSELoss() + elif self.hyperparams["loss"] == "mae": + self.loss = nn.L1Loss() + else: + raise ValueError( + f"Loss type '{self.hyperparams['loss']}' is not valid. " + f"Currently, supports 'mse' or 'mae' loss." + ) + def __init__(self, **kwargs: Any) -> None: """Initialize a new LightningModule for training simple regression models. 
@@ -114,10 +130,10 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor: """ batch = args[0] x = batch["image"] - y = batch["label"].view(-1, 1) + y = batch[self.target_key] y_hat = self(x) - loss = F.mse_loss(y_hat, y) + loss = self.loss(y_hat, y) self.log("train_loss", loss) # logging to TensorBoard self.train_metrics(y_hat, y) @@ -139,10 +155,10 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: batch = args[0] batch_idx = args[1] x = batch["image"] - y = batch["label"].view(-1, 1) + y = batch[self.target_key] y_hat = self(x) - loss = F.mse_loss(y_hat, y) + loss = self.loss(y_hat, y) self.log("val_loss", loss) self.val_metrics(y_hat, y) @@ -156,7 +172,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: try: datamodule = self.trainer.datamodule batch["prediction"] = y_hat - for key in ["image", "label", "prediction"]: + for key in ["image", self.target_key, "prediction"]: batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] fig = datamodule.plot(sample) @@ -181,10 +197,10 @@ def test_step(self, *args: Any, **kwargs: Any) -> None: """ batch = args[0] x = batch["image"] - y = batch["label"].view(-1, 1) + y = batch[self.target_key] y_hat = self(x) - loss = F.mse_loss(y_hat, y) + loss = self.loss(y_hat, y) self.log("test_loss", loss) self.test_metrics(y_hat, y) @@ -227,7 +243,7 @@ def configure_optimizers(self) -> dict[str, Any]: } -class DenseRegressionTask(RegressionTask): +class PixelwiseRegressionTask(RegressionTask): """LightningModule for dense regression of images. Supports `Segmentation Models Pytorch @@ -238,8 +254,10 @@ class DenseRegressionTask(RegressionTask): .. versionadded:: 0.5 """ - def config_task(self) -> None: - """Configures the task based on kwargs parameters passed to the constructor.""" + target_key: str = "mask" + + def config_model(self) -> None: + """Configures the model based on kwargs parameters.""" if self.hyperparams["model"] == "unet": self.model = smp.Unet( encoder_name=self.hyperparams["backbone"], @@ -265,89 +283,3 @@ def config_task(self) -> None: f"Model type '{self.hyperparams['model']}' is not valid. " f"Currently, only supports 'unet', 'deeplabv3+' and 'fcn'." ) - - self.loss: nn.Module - if self.hyperparams["loss"] == "mse": - self.loss = nn.MSELoss() - elif self.hyperparams["loss"] == "mae": - self.loss = nn.L1Loss() - else: - raise ValueError( - f"Loss type '{self.hyperparams['loss']}' is not valid. " - f"Currently, supports 'mse' or 'mae' loss." - ) - - def training_step(self, *args: Any, **kwargs: Any) -> Tensor: - """Compute and return the training loss. - - Args: - batch: the output of your DataLoader - - Returns: - training loss - """ - batch = args[0] - x = batch["image"] - y = batch["mask"] - y_hat = self(x) - - loss = self.loss(y_hat, y) - - self.log("train_loss", loss) # logging to TensorBoard - self.train_metrics(y_hat, y) - - return cast(Tensor, loss) - - def validation_step(self, *args: Any, **kwargs: Any) -> None: - """Compute validation loss and log example predictions. 
- - Args: - batch: the output of your DataLoader - batch_idx: the index of this batch - """ - batch = args[0] - batch_idx = args[1] - x = batch["image"] - y = batch["mask"] - y_hat = self(x) - - loss = self.loss(y_hat, y) - self.log("val_loss", loss) - self.val_metrics(y_hat, y) - - if ( - batch_idx < 10 - and hasattr(self.trainer, "datamodule") - and self.logger - and hasattr(self.logger, "experiment") - and hasattr(self.logger.experiment, "add_figure") - ): - try: - datamodule = self.trainer.datamodule - batch["prediction"] = y_hat - for key in ["image", "mask", "prediction"]: - batch[key] = batch[key].cpu() - sample = unbind_samples(batch)[0] - fig = datamodule.plot(sample) - summary_writer = self.logger.experiment - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - plt.close() - except ValueError: - pass - - def test_step(self, *args: Any, **kwargs: Any) -> None: - """Compute test loss. - - Args: - batch: the output of your DataLoader - """ - batch = args[0] - x = batch["image"] - y = batch["mask"] - y_hat = self(x) - - loss = self.loss(y_hat, y) - self.log("test_loss", loss) - self.test_metrics(y_hat, y) From 381a9d9bde169e413440c69c800d150479f04829 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 21 Apr 2023 21:00:50 +0000 Subject: [PATCH 15/24] rename --- torchgeo/trainers/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py index d62a7cde076..e1db43fd3e5 100644 --- a/torchgeo/trainers/__init__.py +++ b/torchgeo/trainers/__init__.py @@ -6,15 +6,15 @@ from .byol import BYOLTask from .classification import ClassificationTask, MultiLabelClassificationTask from .detection import ObjectDetectionTask -from .regression import DenseRegressionTask, RegressionTask +from .regression import PixelwiseRegressionTask, RegressionTask from .segmentation import SemanticSegmentationTask __all__ = ( "BYOLTask", "ClassificationTask", - "DenseRegressionTask", "MultiLabelClassificationTask", "ObjectDetectionTask", + "PixelwiseRegressionTask", "RegressionTask", "SemanticSegmentationTask", ) From a36d74d31f7ea5d4e3b9e2426afc5f62eeeb6afb Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 21 Apr 2023 21:06:26 +0000 Subject: [PATCH 16/24] add loss to test configs --- tests/conf/cowc_counting.yaml | 1 + tests/conf/cyclone.yaml | 1 + tests/conf/skippd.yaml | 1 + tests/conf/sustainbench_crop_yield.yaml | 1 + 4 files changed, 4 insertions(+) diff --git a/tests/conf/cowc_counting.yaml b/tests/conf/cowc_counting.yaml index fc3218e8fef..041748de953 100644 --- a/tests/conf/cowc_counting.yaml +++ b/tests/conf/cowc_counting.yaml @@ -7,6 +7,7 @@ experiment: in_channels: 3 learning_rate: 1e-3 learning_rate_schedule_patience: 2 + loss: "mse" datamodule: root: "tests/data/cowc_counting" download: true diff --git a/tests/conf/cyclone.yaml b/tests/conf/cyclone.yaml index b3323d28999..29a601b3f3c 100644 --- a/tests/conf/cyclone.yaml +++ b/tests/conf/cyclone.yaml @@ -7,6 +7,7 @@ experiment: in_channels: 3 learning_rate: 1e-3 learning_rate_schedule_patience: 2 + loss: "mse" datamodule: root: "tests/data/cyclone" download: true diff --git a/tests/conf/skippd.yaml b/tests/conf/skippd.yaml index 8f1c1cb655f..2c50192665f 100644 --- a/tests/conf/skippd.yaml +++ b/tests/conf/skippd.yaml @@ -7,6 +7,7 @@ experiment: in_channels: 3 learning_rate: 1e-3 learning_rate_schedule_patience: 2 + loss: "mse" 
datamodule: root: "tests/data/skippd" download: true diff --git a/tests/conf/sustainbench_crop_yield.yaml b/tests/conf/sustainbench_crop_yield.yaml index 60903ea7d4c..c9cf58a53d8 100644 --- a/tests/conf/sustainbench_crop_yield.yaml +++ b/tests/conf/sustainbench_crop_yield.yaml @@ -7,6 +7,7 @@ experiment: in_channels: 9 learning_rate: 1e-3 learning_rate_schedule_patience: 2 + loss: "mse" datamodule: root: "tests/data/sustainbench_crop_yield" download: true From 2ae4d8b84ee275f7f34ad9f02659e1d89bff6ba0 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 21 Apr 2023 21:06:47 +0000 Subject: [PATCH 17/24] rename dense to pixelwise --- torchgeo/trainers/regression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 944993ab32d..70cfe1b7797 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -244,7 +244,7 @@ def configure_optimizers(self) -> dict[str, Any]: class PixelwiseRegressionTask(RegressionTask): - """LightningModule for dense regression of images. + """LightningModule for pixelwise regression of images. Supports `Segmentation Models Pytorch `_ From e6fbbdb26d09107c91327d256b78dc62cb7458c2 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 21 Apr 2023 21:14:15 +0000 Subject: [PATCH 18/24] if batch_size=1 then targets will have dims=1 but need dims=2 for loss --- tests/trainers/test_regression.py | 7 +++++++ torchgeo/trainers/regression.py | 12 ++++++++++++ 2 files changed, 19 insertions(+) diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index f4210a7cfa9..e495054973a 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -103,6 +103,7 @@ def model_kwargs(self) -> dict[str, Any]: "weights": None, "num_outputs": 1, "in_channels": 3, + "loss": "mse", } @pytest.fixture( @@ -198,3 +199,9 @@ def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None max_epochs=1, ) trainer.predict(model=model, datamodule=datamodule) + + def test_invalid_loss(self, model_kwargs: dict[str, Any]) -> None: + model_kwargs["loss"] = "invalid_loss" + match = "Loss type 'invalid_loss' is not valid." 
+ with pytest.raises(ValueError, match=match): + RegressionTask(**model_kwargs) diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 70cfe1b7797..1cb00c04a18 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -131,6 +131,10 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor: batch = args[0] x = batch["image"] y = batch[self.target_key] + + if y.ndim == 1: + y = y.unsqueeze(dim=1) + y_hat = self(x) loss = self.loss(y_hat, y) @@ -156,6 +160,10 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: batch_idx = args[1] x = batch["image"] y = batch[self.target_key] + + if y.ndim == 1: + y = y.unsqueeze(dim=1) + y_hat = self(x) loss = self.loss(y_hat, y) @@ -198,6 +206,10 @@ def test_step(self, *args: Any, **kwargs: Any) -> None: batch = args[0] x = batch["image"] y = batch[self.target_key] + + if y.ndim == 1: + y = y.unsqueeze(dim=1) + y_hat = self(x) loss = self.loss(y_hat, y) From 8741d6bae2bf5d7903a5c767bcd1bb7801a61aa0 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 21 Apr 2023 21:57:44 +0000 Subject: [PATCH 19/24] add tests' --- tests/trainers/test_regression.py | 85 ++++++++++++++++++++++++++++++- torchgeo/trainers/regression.py | 31 ++++++----- 2 files changed, 99 insertions(+), 17 deletions(-) diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index e495054973a..1b31d4108ca 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -6,17 +6,21 @@ from typing import Any, cast import pytest +import segmentation_models_pytorch as smp import timm import torch +import torch.nn as nn import torchvision from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch from lightning.pytorch import LightningDataModule, Trainer from omegaconf import OmegaConf +from torch.nn.modules import Module from torchvision.models._api import WeightsEnum from torchgeo.datamodules import ( COWCCountingDataModule, + InriaAerialImageLabelingDataModule, MisconfigurationException, SKIPPDDataModule, SustainBenchCropYieldDataModule, @@ -24,11 +28,22 @@ ) from torchgeo.datasets import TropicalCyclone from torchgeo.models import get_model_weights, list_models -from torchgeo.trainers import RegressionTask +from torchgeo.trainers import PixelwiseRegressionTask, RegressionTask from .test_classification import ClassificationTestModel +class PixelwiseRegressionTestModel(Module): + def __init__(self, in_channels: int = 3, classes: int = 1, **kwargs: Any) -> None: + super().__init__() + self.conv1 = nn.Conv2d( + in_channels=in_channels, out_channels=classes, kernel_size=1, padding=0 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return cast(torch.Tensor, self.conv1(x)) + + class RegressionTestModel(ClassificationTestModel): def __init__(self, in_chans: int = 3, num_classes: int = 1, **kwargs: Any) -> None: super().__init__(in_chans=in_chans, num_classes=num_classes) @@ -48,6 +63,10 @@ def plot(*args: Any, **kwargs: Any) -> None: raise ValueError +def create_model(**kwargs: Any) -> Module: + return PixelwiseRegressionTestModel(**kwargs) + + class TestRegressionTask: @pytest.mark.parametrize( "name,classname", @@ -205,3 +224,67 @@ def test_invalid_loss(self, model_kwargs: dict[str, Any]) -> None: match = "Loss type 'invalid_loss' is not valid." 
with pytest.raises(ValueError, match=match): RegressionTask(**model_kwargs) + + +class TestPixelwiseRegressionTask: + @pytest.mark.parametrize( + "name,classname,batch_size,loss", + [ + ("inria", InriaAerialImageLabelingDataModule, 1, "mse"), + ("inria", InriaAerialImageLabelingDataModule, 2, "mae"), + ], + ) + def test_trainer( + self, + monkeypatch: MonkeyPatch, + name: str, + classname: type[LightningDataModule], + batch_size: int, + loss: str, + fast_dev_run: bool, + ) -> None: + conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) + conf_dict = OmegaConf.to_object(conf.experiment) + conf_dict = cast(dict[str, dict[str, Any]], conf_dict) + + # Instantiate datamodule + datamodule_kwargs = conf_dict["datamodule"] + datamodule_kwargs["batch_size"] = batch_size + datamodule = classname(**datamodule_kwargs) + + # Instantiate model + monkeypatch.setattr(smp, "Unet", create_model) + monkeypatch.setattr(smp, "DeepLabV3Plus", create_model) + model_kwargs = conf_dict["module"] + model_kwargs["loss"] = loss + model = PixelwiseRegressionTask(**model_kwargs) + + model.model = PixelwiseRegressionTestModel(in_chans=model_kwargs["in_channels"]) + + # Instantiate trainer + trainer = Trainer( + accelerator="cpu", + fast_dev_run=fast_dev_run, + log_every_n_steps=1, + max_epochs=1, + ) + + trainer.fit(model=model, datamodule=datamodule) + try: + trainer.test(model=model, datamodule=datamodule) + except MisconfigurationException: + pass + try: + trainer.predict(model=model, datamodule=datamodule) + except MisconfigurationException: + pass + + @pytest.fixture + def model_kwargs(self) -> dict[str, Any]: + return { + "model": "resnet18", + "weights": None, + "num_outputs": 1, + "in_channels": 3, + "loss": "mse", + } diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 1cb00c04a18..23d1754d25e 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -131,16 +131,14 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor: batch = args[0] x = batch["image"] y = batch[self.target_key] - - if y.ndim == 1: - y = y.unsqueeze(dim=1) - y_hat = self(x) - loss = self.loss(y_hat, y) + if y_hat.ndim != y.ndim: + y = y.unsqueeze(dim=1) + loss = self.loss(y_hat, y.to(torch.float)) self.log("train_loss", loss) # logging to TensorBoard - self.train_metrics(y_hat, y) + self.train_metrics(y_hat, y.to(torch.float)) return loss @@ -160,15 +158,14 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: batch_idx = args[1] x = batch["image"] y = batch[self.target_key] + y_hat = self(x) - if y.ndim == 1: + if y_hat.ndim != y.ndim: y = y.unsqueeze(dim=1) - y_hat = self(x) - - loss = self.loss(y_hat, y) + loss = self.loss(y_hat, y.to(torch.float)) self.log("val_loss", loss) - self.val_metrics(y_hat, y) + self.val_metrics(y_hat, y.to(torch.float)) if ( batch_idx < 10 @@ -179,6 +176,9 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: ): try: datamodule = self.trainer.datamodule + if self.target_key == "mask": + y = y.squeeze(dim=1) + y_hat = y_hat.squeeze(dim=1) batch["prediction"] = y_hat for key in ["image", self.target_key, "prediction"]: batch[key] = batch[key].cpu() @@ -206,15 +206,14 @@ def test_step(self, *args: Any, **kwargs: Any) -> None: batch = args[0] x = batch["image"] y = batch[self.target_key] + y_hat = self(x) - if y.ndim == 1: + if y_hat.ndim != y.ndim: y = y.unsqueeze(dim=1) - y_hat = self(x) - - loss = self.loss(y_hat, y) + loss = self.loss(y_hat, y.to(torch.float)) self.log("test_loss", loss) - 
self.test_metrics(y_hat, y) + self.test_metrics(y_hat, y.to(torch.float)) def on_test_epoch_end(self) -> None: """Logs epoch level test metrics.""" From 74a14d6506dde9d2379a1d442f5eb0ccc08ad363 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 21 Apr 2023 22:47:52 +0000 Subject: [PATCH 20/24] fix mypy errors --- torchgeo/trainers/regression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 23d1754d25e..c20b556e565 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -136,7 +136,7 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor: if y_hat.ndim != y.ndim: y = y.unsqueeze(dim=1) - loss = self.loss(y_hat, y.to(torch.float)) + loss: Tensor = self.loss(y_hat, y.to(torch.float)) self.log("train_loss", loss) # logging to TensorBoard self.train_metrics(y_hat, y.to(torch.float)) From e8b001ff5867f2ae852ac7ef9a1f9b6d504e5ab9 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 21 Apr 2023 22:58:44 +0000 Subject: [PATCH 21/24] get full coverage --- tests/trainers/test_regression.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 1b31d4108ca..75835bcac02 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -228,10 +228,11 @@ def test_invalid_loss(self, model_kwargs: dict[str, Any]) -> None: class TestPixelwiseRegressionTask: @pytest.mark.parametrize( - "name,classname,batch_size,loss", + "name,classname,batch_size,loss,model_type", [ - ("inria", InriaAerialImageLabelingDataModule, 1, "mse"), - ("inria", InriaAerialImageLabelingDataModule, 2, "mae"), + ("inria", InriaAerialImageLabelingDataModule, 1, "mse", "unet"), + ("inria", InriaAerialImageLabelingDataModule, 2, "mae", "deeplabv3+"), + ("inria", InriaAerialImageLabelingDataModule, 1, "mse", "fcn"), ], ) def test_trainer( @@ -241,6 +242,7 @@ def test_trainer( classname: type[LightningDataModule], batch_size: int, loss: str, + model_type: str, fast_dev_run: bool, ) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) @@ -257,6 +259,11 @@ def test_trainer( monkeypatch.setattr(smp, "DeepLabV3Plus", create_model) model_kwargs = conf_dict["module"] model_kwargs["loss"] = loss + model_kwargs["model"] = model_type + + if model_type == "fcn": + model_kwargs["num_filters"] = 2 + model = PixelwiseRegressionTask(**model_kwargs) model.model = PixelwiseRegressionTestModel(in_chans=model_kwargs["in_channels"]) From d71a2521e51b3342b32698850ba5bb2e02ae1c99 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 21 Apr 2023 23:28:56 +0000 Subject: [PATCH 22/24] test for invalid model --- tests/trainers/test_regression.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 75835bcac02..22bc5abd3d4 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -286,6 +286,12 @@ def test_trainer( except MisconfigurationException: pass + def test_invalid_model(self, model_kwargs: dict[str, Any]) -> None: + model_kwargs["model"] = "invalid_model" + match = "Model type 'invalid_model' is not valid." 
+ with pytest.raises(ValueError, match=match): + PixelwiseRegressionTask(**model_kwargs) + @pytest.fixture def model_kwargs(self) -> dict[str, Any]: return { From 5bc1f5d315f8e4fae1574b5db3c436bd99ab9d48 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Mon, 24 Apr 2023 18:35:34 +0000 Subject: [PATCH 23/24] update tests --- tests/trainers/test_regression.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 5c09ea418e2..c83521365fc 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -3,7 +3,7 @@ import os from pathlib import Path -from typing import Any +from typing import Any, cast import pytest import segmentation_models_pytorch as smp @@ -209,45 +209,41 @@ def test_invalid_loss(self, model_kwargs: dict[str, Any]) -> None: class TestPixelwiseRegressionTask: @pytest.mark.parametrize( - "name,classname,batch_size,loss,model_type", + "name,batch_size,loss,model_type", [ - ("inria", InriaAerialImageLabelingDataModule, 1, "mse", "unet"), - ("inria", InriaAerialImageLabelingDataModule, 2, "mae", "deeplabv3+"), - ("inria", InriaAerialImageLabelingDataModule, 1, "mse", "fcn"), + ("inria", 1, "mse", "unet"), + ("inria", 2, "mae", "deeplabv3+"), + ("inria", 1, "mse", "fcn"), ], ) def test_trainer( self, monkeypatch: MonkeyPatch, name: str, - classname: type[LightningDataModule], batch_size: int, loss: str, model_type: str, fast_dev_run: bool, + model_kwargs: dict[str, Any], ) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) - conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(dict[str, dict[str, Any]], conf_dict) # Instantiate datamodule - datamodule_kwargs = conf_dict["datamodule"] - datamodule_kwargs["batch_size"] = batch_size - datamodule = classname(**datamodule_kwargs) + conf.datamodule.batch_size = batch_size + datamodule = instantiate(conf.datamodule) # Instantiate model monkeypatch.setattr(smp, "Unet", create_model) monkeypatch.setattr(smp, "DeepLabV3Plus", create_model) - model_kwargs = conf_dict["module"] - model_kwargs["loss"] = loss model_kwargs["model"] = model_type if model_type == "fcn": model_kwargs["num_filters"] = 2 model = PixelwiseRegressionTask(**model_kwargs) - - model.model = PixelwiseRegressionTestModel(in_chans=model_kwargs["in_channels"]) + model.model = PixelwiseRegressionTestModel( + in_channels=model_kwargs["in_channels"] + ) # Instantiate trainer trainer = Trainer( @@ -276,9 +272,12 @@ def test_invalid_model(self, model_kwargs: dict[str, Any]) -> None: @pytest.fixture def model_kwargs(self) -> dict[str, Any]: return { - "model": "resnet18", + "model": "unet", + "backbone": "resnet18", "weights": None, "num_outputs": 1, "in_channels": 3, "loss": "mse", + "learning_rate": 1e-3, + "learning_rate_schedule_patience": 6, } From b135bdc09468237728dfe32d066dfc120b91fd1f Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Mon, 24 Apr 2023 19:11:37 +0000 Subject: [PATCH 24/24] 100% coverage --- tests/trainers/test_regression.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index c83521365fc..3479ef385e6 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -236,6 +236,7 @@ def test_trainer( monkeypatch.setattr(smp, "Unet", create_model) monkeypatch.setattr(smp, 
"DeepLabV3Plus", create_model) model_kwargs["model"] = model_type + model_kwargs["loss"] = loss if model_type == "fcn": model_kwargs["num_filters"] = 2