diff --git a/tests/conf/cowc_counting.yaml b/tests/conf/cowc_counting.yaml
index 76eb04763a6..c5855bef5fb 100644
--- a/tests/conf/cowc_counting.yaml
+++ b/tests/conf/cowc_counting.yaml
@@ -6,6 +6,7 @@ module:
   in_channels: 3
   learning_rate: 1e-3
   learning_rate_schedule_patience: 2
+  loss: "mse"
 
 datamodule:
   _target_: torchgeo.datamodules.COWCCountingDataModule
diff --git a/tests/conf/cyclone.yaml b/tests/conf/cyclone.yaml
index 91a477a144d..5b096dcfe7b 100644
--- a/tests/conf/cyclone.yaml
+++ b/tests/conf/cyclone.yaml
@@ -6,6 +6,7 @@ module:
   in_channels: 3
   learning_rate: 1e-3
   learning_rate_schedule_patience: 2
+  loss: "mse"
 
 datamodule:
   _target_: torchgeo.datamodules.TropicalCycloneDataModule
diff --git a/tests/conf/skippd.yaml b/tests/conf/skippd.yaml
index 20ca10f24a6..6b1fdfdc22b 100644
--- a/tests/conf/skippd.yaml
+++ b/tests/conf/skippd.yaml
@@ -6,6 +6,7 @@ module:
   in_channels: 3
   learning_rate: 1e-3
   learning_rate_schedule_patience: 2
+  loss: "mse"
 
 datamodule:
   _target_: torchgeo.datamodules.SKIPPDDataModule
diff --git a/tests/conf/sustainbench_crop_yield.yaml b/tests/conf/sustainbench_crop_yield.yaml
index 2f48a83d02a..09fbb37d05a 100644
--- a/tests/conf/sustainbench_crop_yield.yaml
+++ b/tests/conf/sustainbench_crop_yield.yaml
@@ -6,6 +6,7 @@ module:
   in_channels: 9
   learning_rate: 1e-3
   learning_rate_schedule_patience: 2
+  loss: "mse"
 
 datamodule:
   _target_: torchgeo.datamodules.SustainBenchCropYieldDataModule
diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py
index cabfe5a2cdc..3479ef385e6 100644
--- a/tests/trainers/test_regression.py
+++ b/tests/trainers/test_regression.py
@@ -3,27 +3,41 @@
 
 import os
 from pathlib import Path
-from typing import Any
+from typing import Any, cast
 
 import pytest
+import segmentation_models_pytorch as smp
 import timm
 import torch
+import torch.nn as nn
 import torchvision
 from _pytest.fixtures import SubRequest
 from _pytest.monkeypatch import MonkeyPatch
 from hydra.utils import instantiate
 from lightning.pytorch import Trainer
 from omegaconf import OmegaConf
+from torch.nn.modules import Module
 from torchvision.models._api import WeightsEnum
 
 from torchgeo.datamodules import MisconfigurationException, TropicalCycloneDataModule
 from torchgeo.datasets import TropicalCyclone
 from torchgeo.models import get_model_weights, list_models
-from torchgeo.trainers import RegressionTask
+from torchgeo.trainers import PixelwiseRegressionTask, RegressionTask
 
 from .test_classification import ClassificationTestModel
 
 
+class PixelwiseRegressionTestModel(Module):
+    def __init__(self, in_channels: int = 3, classes: int = 1, **kwargs: Any) -> None:
+        super().__init__()
+        self.conv1 = nn.Conv2d(
+            in_channels=in_channels, out_channels=classes, kernel_size=1, padding=0
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return cast(torch.Tensor, self.conv1(x))
+
+
 class RegressionTestModel(ClassificationTestModel):
     def __init__(self, in_chans: int = 3, num_classes: int = 1, **kwargs: Any) -> None:
         super().__init__(in_chans=in_chans, num_classes=num_classes)
@@ -43,6 +57,10 @@ def plot(*args: Any, **kwargs: Any) -> None:
     raise ValueError
 
 
+def create_model(**kwargs: Any) -> Module:
+    return PixelwiseRegressionTestModel(**kwargs)
+
+
 class TestRegressionTask:
     @pytest.mark.parametrize(
         "name", ["cowc_counting", "cyclone", "sustainbench_crop_yield", "skippd"]
     )
@@ -85,6 +103,7 @@ def model_kwargs(self) -> dict[str, Any]:
             "weights": None,
             "num_outputs": 1,
             "in_channels": 3,
+            "loss": "mse",
         }
 
     @pytest.fixture(
@@ -180,3 +199,86 @@ def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None
             max_epochs=1,
         )
         trainer.predict(model=model, datamodule=datamodule)
+
+    def test_invalid_loss(self, model_kwargs: dict[str, Any]) -> None:
+        model_kwargs["loss"] = "invalid_loss"
+        match = "Loss type 'invalid_loss' is not valid."
+        with pytest.raises(ValueError, match=match):
+            RegressionTask(**model_kwargs)
+
+
+class TestPixelwiseRegressionTask:
+    @pytest.mark.parametrize(
+        "name,batch_size,loss,model_type",
+        [
+            ("inria", 1, "mse", "unet"),
+            ("inria", 2, "mae", "deeplabv3+"),
+            ("inria", 1, "mse", "fcn"),
+        ],
+    )
+    def test_trainer(
+        self,
+        monkeypatch: MonkeyPatch,
+        name: str,
+        batch_size: int,
+        loss: str,
+        model_type: str,
+        fast_dev_run: bool,
+        model_kwargs: dict[str, Any],
+    ) -> None:
+        conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
+
+        # Instantiate datamodule
+        conf.datamodule.batch_size = batch_size
+        datamodule = instantiate(conf.datamodule)
+
+        # Instantiate model
+        monkeypatch.setattr(smp, "Unet", create_model)
+        monkeypatch.setattr(smp, "DeepLabV3Plus", create_model)
+        model_kwargs["model"] = model_type
+        model_kwargs["loss"] = loss
+
+        if model_type == "fcn":
+            model_kwargs["num_filters"] = 2
+
+        model = PixelwiseRegressionTask(**model_kwargs)
+        model.model = PixelwiseRegressionTestModel(
+            in_channels=model_kwargs["in_channels"]
+        )
+
+        # Instantiate trainer
+        trainer = Trainer(
+            accelerator="cpu",
+            fast_dev_run=fast_dev_run,
+            log_every_n_steps=1,
+            max_epochs=1,
+        )
+
+        trainer.fit(model=model, datamodule=datamodule)
+        try:
+            trainer.test(model=model, datamodule=datamodule)
+        except MisconfigurationException:
+            pass
+        try:
+            trainer.predict(model=model, datamodule=datamodule)
+        except MisconfigurationException:
+            pass
+
+    def test_invalid_model(self, model_kwargs: dict[str, Any]) -> None:
+        model_kwargs["model"] = "invalid_model"
+        match = "Model type 'invalid_model' is not valid."
+        with pytest.raises(ValueError, match=match):
+            PixelwiseRegressionTask(**model_kwargs)
+
+    @pytest.fixture
+    def model_kwargs(self) -> dict[str, Any]:
+        return {
+            "model": "unet",
+            "backbone": "resnet18",
+            "weights": None,
+            "num_outputs": 1,
+            "in_channels": 3,
+            "loss": "mse",
+            "learning_rate": 1e-3,
+            "learning_rate_schedule_patience": 6,
+        }
diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py
index 6240b53f681..e1db43fd3e5 100644
--- a/torchgeo/trainers/__init__.py
+++ b/torchgeo/trainers/__init__.py
@@ -6,7 +6,7 @@
 from .byol import BYOLTask
 from .classification import ClassificationTask, MultiLabelClassificationTask
 from .detection import ObjectDetectionTask
-from .regression import RegressionTask
+from .regression import PixelwiseRegressionTask, RegressionTask
 from .segmentation import SemanticSegmentationTask
 
 __all__ = (
@@ -14,6 +14,7 @@
     "ClassificationTask",
     "MultiLabelClassificationTask",
     "ObjectDetectionTask",
+    "PixelwiseRegressionTask",
     "RegressionTask",
     "SemanticSegmentationTask",
 )
diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py
index 45daf0788b7..c20b556e565 100644
--- a/torchgeo/trainers/regression.py
+++ b/torchgeo/trainers/regression.py
@@ -7,9 +7,10 @@
 from typing import Any, cast
 
 import matplotlib.pyplot as plt
+import segmentation_models_pytorch as smp
 import timm
 import torch
-import torch.nn.functional as F
+import torch.nn as nn
 from lightning.pytorch import LightningModule
 from torch import Tensor
 from torch.optim.lr_scheduler import ReduceLROnPlateau
@@ -17,7 +18,7 @@
 from torchvision.models._api import WeightsEnum
 
 from ..datasets import unbind_samples
-from ..models import get_weight
+from ..models import FCN, get_weight
 from . import utils
 
 
@@ -35,8 +36,10 @@ class RegressionTask(LightningModule):  # type: ignore[misc]
         print(timm.list_models())
     """
 
-    def config_task(self) -> None:
-        """Configures the task based on kwargs parameters."""
+    target_key: str = "label"
+
+    def config_model(self) -> None:
+        """Configures the model based on kwargs parameters."""
         # Create model
         weights = self.hyperparams["weights"]
         self.model = timm.create_model(
@@ -56,6 +59,21 @@ def config_task(self) -> None:
             state_dict = get_weight(weights).get_state_dict(progress=True)
             self.model = utils.load_state_dict(self.model, state_dict)
 
+    def config_task(self) -> None:
+        """Configures the task based on kwargs parameters."""
+        self.config_model()
+
+        self.loss: nn.Module
+        if self.hyperparams["loss"] == "mse":
+            self.loss = nn.MSELoss()
+        elif self.hyperparams["loss"] == "mae":
+            self.loss = nn.L1Loss()
+        else:
+            raise ValueError(
+                f"Loss type '{self.hyperparams['loss']}' is not valid. "
+                f"Currently, supports 'mse' or 'mae' loss."
+            )
+
     def __init__(self, **kwargs: Any) -> None:
         """Initialize a new LightningModule for training simple regression models.
 
@@ -80,7 +98,11 @@ def __init__(self, **kwargs: Any) -> None:
         self.config_task()
 
         self.train_metrics = MetricCollection(
-            {"RMSE": MeanSquaredError(squared=False), "MAE": MeanAbsoluteError()},
+            {
+                "RMSE": MeanSquaredError(squared=False),
+                "MSE": MeanSquaredError(squared=True),
+                "MAE": MeanAbsoluteError(),
+            },
             prefix="train_",
         )
         self.val_metrics = self.train_metrics.clone(prefix="val_")
@@ -108,13 +130,15 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
         """
         batch = args[0]
         x = batch["image"]
-        y = batch["label"].view(-1, 1)
+        y = batch[self.target_key]
         y_hat = self(x)
 
-        loss = F.mse_loss(y_hat, y)
+        if y_hat.ndim != y.ndim:
+            y = y.unsqueeze(dim=1)
+        loss: Tensor = self.loss(y_hat, y.to(torch.float))
         self.log("train_loss", loss)  # logging to TensorBoard
-        self.train_metrics(y_hat, y)
+        self.train_metrics(y_hat, y.to(torch.float))
 
         return loss
 
@@ -133,12 +157,15 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None:
         batch = args[0]
         batch_idx = args[1]
         x = batch["image"]
-        y = batch["label"].view(-1, 1)
+        y = batch[self.target_key]
         y_hat = self(x)
 
-        loss = F.mse_loss(y_hat, y)
+        if y_hat.ndim != y.ndim:
+            y = y.unsqueeze(dim=1)
+
+        loss = self.loss(y_hat, y.to(torch.float))
         self.log("val_loss", loss)
-        self.val_metrics(y_hat, y)
+        self.val_metrics(y_hat, y.to(torch.float))
 
         if (
             batch_idx < 10
@@ -149,8 +176,11 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None:
         ):
             try:
                 datamodule = self.trainer.datamodule
+                if self.target_key == "mask":
+                    y = y.squeeze(dim=1)
+                    y_hat = y_hat.squeeze(dim=1)
                 batch["prediction"] = y_hat
-                for key in ["image", "label", "prediction"]:
+                for key in ["image", self.target_key, "prediction"]:
                     batch[key] = batch[key].cpu()
                 sample = unbind_samples(batch)[0]
                 fig = datamodule.plot(sample)
@@ -175,12 +205,15 @@ def test_step(self, *args: Any, **kwargs: Any) -> None:
         """
         batch = args[0]
         x = batch["image"]
-        y = batch["label"].view(-1, 1)
+        y = batch[self.target_key]
         y_hat = self(x)
 
-        loss = F.mse_loss(y_hat, y)
+        if y_hat.ndim != y.ndim:
+            y = y.unsqueeze(dim=1)
+
+        loss = self.loss(y_hat, y.to(torch.float))
         self.log("test_loss", loss)
-        self.test_metrics(y_hat, y)
+        self.test_metrics(y_hat, y.to(torch.float))
 
     def on_test_epoch_end(self) -> None:
         """Logs epoch level test metrics."""
@@ -219,3 +252,45 @@ def configure_optimizers(self) -> dict[str, Any]:
                 "monitor": "val_loss",
             },
         }
+
+
+class PixelwiseRegressionTask(RegressionTask):
+    """LightningModule for pixelwise regression of images.
+
+    Supports `Segmentation Models Pytorch
+    <https://github.com/qubvel/segmentation_models.pytorch>`_
+    as an architecture choice in combination with any of these
+    `TIMM backbones <https://smp.readthedocs.io/en/latest/encoders_timm.html>`_.
+
+    .. versionadded:: 0.5
+    """
+
+    target_key: str = "mask"
+
+    def config_model(self) -> None:
+        """Configures the model based on kwargs parameters."""
+        if self.hyperparams["model"] == "unet":
+            self.model = smp.Unet(
+                encoder_name=self.hyperparams["backbone"],
+                encoder_weights=self.hyperparams["weights"],
+                in_channels=self.hyperparams["in_channels"],
+                classes=1,
+            )
+        elif self.hyperparams["model"] == "deeplabv3+":
+            self.model = smp.DeepLabV3Plus(
+                encoder_name=self.hyperparams["backbone"],
+                encoder_weights=self.hyperparams["weights"],
+                in_channels=self.hyperparams["in_channels"],
+                classes=1,
+            )
+        elif self.hyperparams["model"] == "fcn":
+            self.model = FCN(
+                in_channels=self.hyperparams["in_channels"],
+                classes=1,
+                num_filters=self.hyperparams["num_filters"],
+            )
+        else:
+            raise ValueError(
+                f"Model type '{self.hyperparams['model']}' is not valid. "
+                f"Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
+            )