Add PixelwiseRegressionTask (#1241)
isaaccorley authored Apr 25, 2023
1 parent 4f714f7 commit 7678627
Showing 7 changed files with 200 additions and 18 deletions.
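This commit adds PixelwiseRegressionTask, a trainer for per-pixel regression that reuses the RegressionTask plumbing but builds its model from Segmentation Models PyTorch (U-Net, DeepLabV3+) or TorchGeo's FCN, and makes the regression loss configurable ("mse" or "mae"). The sketch below shows how the new task might be driven end to end; the task kwargs mirror the test fixture added in this commit, while the datamodule choice and its root path are assumptions, not part of the diff.

# Hedged usage sketch; the datamodule and its paths are placeholders/assumptions.
from lightning.pytorch import Trainer

from torchgeo.datamodules import InriaAerialImageLabelingDataModule  # assumed choice
from torchgeo.trainers import PixelwiseRegressionTask

task = PixelwiseRegressionTask(
    model="unet",
    backbone="resnet18",
    weights=None,
    in_channels=3,
    num_outputs=1,
    loss="mse",
    learning_rate=1e-3,
    learning_rate_schedule_patience=6,
)
datamodule = InriaAerialImageLabelingDataModule(root="path/to/inria", batch_size=1)
trainer = Trainer(accelerator="cpu", max_epochs=1)
trainer.fit(model=task, datamodule=datamodule)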
1 change: 1 addition & 0 deletions tests/conf/cowc_counting.yaml
@@ -6,6 +6,7 @@ module:
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2
loss: "mse"

datamodule:
_target_: torchgeo.datamodules.COWCCountingDataModule
1 change: 1 addition & 0 deletions tests/conf/cyclone.yaml
@@ -6,6 +6,7 @@ module:
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2
loss: "mse"

datamodule:
_target_: torchgeo.datamodules.TropicalCycloneDataModule
1 change: 1 addition & 0 deletions tests/conf/skippd.yaml
@@ -6,6 +6,7 @@ module:
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2
loss: "mse"

datamodule:
_target_: torchgeo.datamodules.SKIPPDDataModule
1 change: 1 addition & 0 deletions tests/conf/sustainbench_crop_yield.yaml
@@ -6,6 +6,7 @@ module:
in_channels: 9
learning_rate: 1e-3
learning_rate_schedule_patience: 2
loss: "mse"

datamodule:
_target_: torchgeo.datamodules.SustainBenchCropYieldDataModule
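Each of the four configs above gains the same new loss: "mse" entry, which the updated RegressionTask now consumes. A minimal sketch of the equivalent keyword arguments follows; the "model" value is an assumption, since only the changed lines of these configs are shown here.

# Hedged sketch: constructing RegressionTask with kwargs mirroring the module sections above.
from torchgeo.trainers import RegressionTask

task = RegressionTask(
    model="resnet18",  # assumed backbone name; not visible in the config excerpts
    weights=None,
    num_outputs=1,
    in_channels=3,
    learning_rate=1e-3,
    learning_rate_schedule_patience=2,
    loss="mse",  # newly required key; "mae" selects nn.L1Loss instead of nn.MSELoss
)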
106 changes: 104 additions & 2 deletions tests/trainers/test_regression.py
@@ -3,27 +3,41 @@

import os
from pathlib import Path
from typing import Any
from typing import Any, cast

import pytest
import segmentation_models_pytorch as smp
import timm
import torch
import torch.nn as nn
import torchvision
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from omegaconf import OmegaConf
from torch.nn.modules import Module
from torchvision.models._api import WeightsEnum

from torchgeo.datamodules import MisconfigurationException, TropicalCycloneDataModule
from torchgeo.datasets import TropicalCyclone
from torchgeo.models import get_model_weights, list_models
from torchgeo.trainers import RegressionTask
from torchgeo.trainers import PixelwiseRegressionTask, RegressionTask

from .test_classification import ClassificationTestModel


class PixelwiseRegressionTestModel(Module):
def __init__(self, in_channels: int = 3, classes: int = 1, **kwargs: Any) -> None:
super().__init__()
self.conv1 = nn.Conv2d(
in_channels=in_channels, out_channels=classes, kernel_size=1, padding=0
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return cast(torch.Tensor, self.conv1(x))


class RegressionTestModel(ClassificationTestModel):
def __init__(self, in_chans: int = 3, num_classes: int = 1, **kwargs: Any) -> None:
super().__init__(in_chans=in_chans, num_classes=num_classes)
@@ -43,6 +57,10 @@ def plot(*args: Any, **kwargs: Any) -> None:
raise ValueError


def create_model(**kwargs: Any) -> Module:
return PixelwiseRegressionTestModel(**kwargs)


class TestRegressionTask:
@pytest.mark.parametrize(
"name", ["cowc_counting", "cyclone", "sustainbench_crop_yield", "skippd"]
@@ -85,6 +103,7 @@ def model_kwargs(self) -> dict[str, Any]:
"weights": None,
"num_outputs": 1,
"in_channels": 3,
"loss": "mse",
}

@pytest.fixture(
Expand Down Expand Up @@ -180,3 +199,86 @@ def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None
max_epochs=1,
)
trainer.predict(model=model, datamodule=datamodule)

def test_invalid_loss(self, model_kwargs: dict[str, Any]) -> None:
model_kwargs["loss"] = "invalid_loss"
match = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=match):
RegressionTask(**model_kwargs)


class TestPixelwiseRegressionTask:
@pytest.mark.parametrize(
"name,batch_size,loss,model_type",
[
("inria", 1, "mse", "unet"),
("inria", 2, "mae", "deeplabv3+"),
("inria", 1, "mse", "fcn"),
],
)
def test_trainer(
self,
monkeypatch: MonkeyPatch,
name: str,
batch_size: int,
loss: str,
model_type: str,
fast_dev_run: bool,
model_kwargs: dict[str, Any],
) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))

# Instantiate datamodule
conf.datamodule.batch_size = batch_size
datamodule = instantiate(conf.datamodule)

# Instantiate model
monkeypatch.setattr(smp, "Unet", create_model)
monkeypatch.setattr(smp, "DeepLabV3Plus", create_model)
model_kwargs["model"] = model_type
model_kwargs["loss"] = loss

if model_type == "fcn":
model_kwargs["num_filters"] = 2

model = PixelwiseRegressionTask(**model_kwargs)
model.model = PixelwiseRegressionTestModel(
in_channels=model_kwargs["in_channels"]
)

# Instantiate trainer
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)

trainer.fit(model=model, datamodule=datamodule)
try:
trainer.test(model=model, datamodule=datamodule)
except MisconfigurationException:
pass
try:
trainer.predict(model=model, datamodule=datamodule)
except MisconfigurationException:
pass

def test_invalid_model(self, model_kwargs: dict[str, Any]) -> None:
model_kwargs["model"] = "invalid_model"
match = "Model type 'invalid_model' is not valid."
with pytest.raises(ValueError, match=match):
PixelwiseRegressionTask(**model_kwargs)

@pytest.fixture
def model_kwargs(self) -> dict[str, Any]:
return {
"model": "unet",
"backbone": "resnet18",
"weights": None,
"num_outputs": 1,
"in_channels": 3,
"loss": "mse",
"learning_rate": 1e-3,
"learning_rate_schedule_patience": 6,
}
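A quick sanity-check sketch of why the 1x1-conv stand-in defined at the top of this file can replace the smp models in these tests: it maps in_channels to a single output channel while preserving spatial size, matching the one-value-per-pixel predictions the task expects. It assumes the PixelwiseRegressionTestModel class from the diff above.

# Hedged sketch using the PixelwiseRegressionTestModel defined earlier in this file.
import torch

model = PixelwiseRegressionTestModel(in_channels=3)
out = model(torch.rand(2, 3, 64, 64))
assert out.shape == (2, 1, 64, 64)  # one regression value per pixel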
3 changes: 2 additions & 1 deletion torchgeo/trainers/__init__.py
@@ -6,14 +6,15 @@
from .byol import BYOLTask
from .classification import ClassificationTask, MultiLabelClassificationTask
from .detection import ObjectDetectionTask
from .regression import RegressionTask
from .regression import PixelwiseRegressionTask, RegressionTask
from .segmentation import SemanticSegmentationTask

__all__ = (
"BYOLTask",
"ClassificationTask",
"MultiLabelClassificationTask",
"ObjectDetectionTask",
"PixelwiseRegressionTask",
"RegressionTask",
"SemanticSegmentationTask",
)
105 changes: 90 additions & 15 deletions torchgeo/trainers/regression.py
@@ -7,17 +7,18 @@
from typing import Any, cast

import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
import timm
import torch
import torch.nn.functional as F
import torch.nn as nn
from lightning.pytorch import LightningModule
from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection
from torchvision.models._api import WeightsEnum

from ..datasets import unbind_samples
from ..models import get_weight
from ..models import FCN, get_weight
from . import utils


@@ -35,8 +36,10 @@ class RegressionTask(LightningModule): # type: ignore[misc]
print(timm.list_models())
"""

def config_task(self) -> None:
"""Configures the task based on kwargs parameters."""
target_key: str = "label"

def config_model(self) -> None:
"""Configures the model based on kwargs parameters."""
# Create model
weights = self.hyperparams["weights"]
self.model = timm.create_model(
@@ -56,6 +59,21 @@ def config_task(self) -> None:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.model = utils.load_state_dict(self.model, state_dict)

def config_task(self) -> None:
"""Configures the task based on kwargs parameters."""
self.config_model()

self.loss: nn.Module
if self.hyperparams["loss"] == "mse":
self.loss = nn.MSELoss()
elif self.hyperparams["loss"] == "mae":
self.loss = nn.L1Loss()
else:
raise ValueError(
f"Loss type '{self.hyperparams['loss']}' is not valid. "
f"Currently, supports 'mse' or 'mae' loss."
)

def __init__(self, **kwargs: Any) -> None:
"""Initialize a new LightningModule for training simple regression models.
@@ -80,7 +98,11 @@ def __init__(self, **kwargs: Any) -> None:
self.config_task()

self.train_metrics = MetricCollection(
{"RMSE": MeanSquaredError(squared=False), "MAE": MeanAbsoluteError()},
{
"RMSE": MeanSquaredError(squared=False),
"MSE": MeanSquaredError(squared=True),
"MAE": MeanAbsoluteError(),
},
prefix="train_",
)
self.val_metrics = self.train_metrics.clone(prefix="val_")
@@ -108,13 +130,15 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
"""
batch = args[0]
x = batch["image"]
y = batch["label"].view(-1, 1)
y = batch[self.target_key]
y_hat = self(x)

loss = F.mse_loss(y_hat, y)
if y_hat.ndim != y.ndim:
y = y.unsqueeze(dim=1)

loss: Tensor = self.loss(y_hat, y.to(torch.float))
self.log("train_loss", loss) # logging to TensorBoard
self.train_metrics(y_hat, y)
self.train_metrics(y_hat, y.to(torch.float))

return loss

@@ -133,12 +157,15 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None:
batch = args[0]
batch_idx = args[1]
x = batch["image"]
y = batch["label"].view(-1, 1)
y = batch[self.target_key]
y_hat = self(x)

loss = F.mse_loss(y_hat, y)
if y_hat.ndim != y.ndim:
y = y.unsqueeze(dim=1)

loss = self.loss(y_hat, y.to(torch.float))
self.log("val_loss", loss)
self.val_metrics(y_hat, y)
self.val_metrics(y_hat, y.to(torch.float))

if (
batch_idx < 10
@@ -149,8 +176,11 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None:
):
try:
datamodule = self.trainer.datamodule
if self.target_key == "mask":
y = y.squeeze(dim=1)
y_hat = y_hat.squeeze(dim=1)
batch["prediction"] = y_hat
for key in ["image", "label", "prediction"]:
for key in ["image", self.target_key, "prediction"]:
batch[key] = batch[key].cpu()
sample = unbind_samples(batch)[0]
fig = datamodule.plot(sample)
@@ -175,12 +205,15 @@ def test_step(self, *args: Any, **kwargs: Any) -> None:
"""
batch = args[0]
x = batch["image"]
y = batch["label"].view(-1, 1)
y = batch[self.target_key]
y_hat = self(x)

loss = F.mse_loss(y_hat, y)
if y_hat.ndim != y.ndim:
y = y.unsqueeze(dim=1)

loss = self.loss(y_hat, y.to(torch.float))
self.log("test_loss", loss)
self.test_metrics(y_hat, y)
self.test_metrics(y_hat, y.to(torch.float))

def on_test_epoch_end(self) -> None:
"""Logs epoch level test metrics."""
@@ -219,3 +252,45 @@ def configure_optimizers(self) -> dict[str, Any]:
"monitor": "val_loss",
},
}


class PixelwiseRegressionTask(RegressionTask):
"""LightningModule for pixelwise regression of images.
Supports `Segmentation Models Pytorch
<https://github.com/qubvel/segmentation_models.pytorch>`_
as an architecture choice in combination with any of these
`TIMM backbones <https://smp.readthedocs.io/en/latest/encoders_timm.html>`_.
.. versionadded:: 0.5
"""

target_key: str = "mask"

def config_model(self) -> None:
"""Configures the model based on kwargs parameters."""
if self.hyperparams["model"] == "unet":
self.model = smp.Unet(
encoder_name=self.hyperparams["backbone"],
encoder_weights=self.hyperparams["weights"],
in_channels=self.hyperparams["in_channels"],
classes=1,
)
elif self.hyperparams["model"] == "deeplabv3+":
self.model = smp.DeepLabV3Plus(
encoder_name=self.hyperparams["backbone"],
encoder_weights=self.hyperparams["weights"],
in_channels=self.hyperparams["in_channels"],
classes=1,
)
elif self.hyperparams["model"] == "fcn":
self.model = FCN(
in_channels=self.hyperparams["in_channels"],
classes=1,
num_filters=self.hyperparams["num_filters"],
)
else:
raise ValueError(
f"Model type '{self.hyperparams['model']}' is not valid. "
f"Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
)
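As a rough illustration of the dispatch in config_model above: "unet" and "deeplabv3+" build smp decoders with classes=1, "fcn" additionally requires a num_filters hyperparameter, and any other value raises ValueError. The snippet below is a sketch rather than part of the commit.

# Hedged sketch of the model choices handled by PixelwiseRegressionTask.config_model.
from torchgeo.trainers import PixelwiseRegressionTask

common = {
    "backbone": "resnet18",
    "weights": None,
    "in_channels": 3,
    "num_outputs": 1,
    "loss": "mae",
    "learning_rate": 1e-3,
    "learning_rate_schedule_patience": 6,
}

unet_task = PixelwiseRegressionTask(model="unet", **common)
fcn_task = PixelwiseRegressionTask(model="fcn", num_filters=2, **common)
# PixelwiseRegressionTask(model="invalid_model", **common) raises ValueError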
