Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PixelwiseRegressionTask #1241

Merged
merged 28 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
c21b847
add dense regression task
isaaccorley Apr 13, 2023
6a2c519
fix mypy
isaaccorley Apr 13, 2023
20dd757
refactor and rename to reduce code duplication
isaaccorley Apr 21, 2023
0942fd2
rename
isaaccorley Apr 21, 2023
39700a8
add loss to test configs
isaaccorley Apr 21, 2023
d6fefb6
rename dense to pixelwise
isaaccorley Apr 21, 2023
12180b9
if batch_size=1 then targets will have dims=1 but need dims=2 for loss
isaaccorley Apr 21, 2023
0b81934
add tests'
isaaccorley Apr 21, 2023
ba565e4
fix mypy errors
isaaccorley Apr 21, 2023
ffda89e
get full coverage
isaaccorley Apr 21, 2023
e05a3b5
test for invalid model
isaaccorley Apr 21, 2023
eff0917
add dense regression task
isaaccorley Apr 13, 2023
298d7bb
fix mypy
isaaccorley Apr 13, 2023
7cf0e26
refactor and rename to reduce code duplication
isaaccorley Apr 21, 2023
381a9d9
rename
isaaccorley Apr 21, 2023
a36d74d
add loss to test configs
isaaccorley Apr 21, 2023
2ae4d8b
rename dense to pixelwise
isaaccorley Apr 21, 2023
e6fbbdb
if batch_size=1 then targets will have dims=1 but need dims=2 for loss
isaaccorley Apr 21, 2023
8741d6b
add tests'
isaaccorley Apr 21, 2023
74a14d6
fix mypy errors
isaaccorley Apr 21, 2023
e8b001f
get full coverage
isaaccorley Apr 21, 2023
d71a252
test for invalid model
isaaccorley Apr 21, 2023
fb5c642
Merge branch 'trainers/dense-regression' of github.com:isaaccorley/to…
isaaccorley Apr 22, 2023
e9ca047
Merge branch 'main' into trainers/dense-regression
isaaccorley Apr 24, 2023
e57eb49
Merge branch 'trainers/dense-regression' of github.com:isaaccorley/to…
isaaccorley Apr 24, 2023
5bc1f5d
update tests
isaaccorley Apr 24, 2023
cdcd3c5
Merge branch 'main' into trainers/dense-regression
isaaccorley Apr 24, 2023
b135bdc
100% coverage
isaaccorley Apr 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/conf/cowc_counting.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ experiment:
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2
loss: "mse"
datamodule:
root: "tests/data/cowc_counting"
download: true
Expand Down
1 change: 1 addition & 0 deletions tests/conf/cyclone.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ experiment:
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2
loss: "mse"
datamodule:
root: "tests/data/cyclone"
download: true
Expand Down
1 change: 1 addition & 0 deletions tests/conf/skippd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ experiment:
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2
loss: "mse"
datamodule:
root: "tests/data/skippd"
download: true
Expand Down
1 change: 1 addition & 0 deletions tests/conf/sustainbench_crop_yield.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ experiment:
in_channels: 9
learning_rate: 1e-3
learning_rate_schedule_patience: 2
loss: "mse"
datamodule:
root: "tests/data/sustainbench_crop_yield"
download: true
Expand Down
92 changes: 91 additions & 1 deletion tests/trainers/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,44 @@
from typing import Any, cast

import pytest
import segmentation_models_pytorch as smp
import timm
import torch
import torch.nn as nn
import torchvision
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from lightning.pytorch import LightningDataModule, Trainer
from omegaconf import OmegaConf
from torch.nn.modules import Module
from torchvision.models._api import WeightsEnum

from torchgeo.datamodules import (
COWCCountingDataModule,
InriaAerialImageLabelingDataModule,
MisconfigurationException,
SKIPPDDataModule,
SustainBenchCropYieldDataModule,
TropicalCycloneDataModule,
)
from torchgeo.datasets import TropicalCyclone
from torchgeo.models import get_model_weights, list_models
from torchgeo.trainers import RegressionTask
from torchgeo.trainers import PixelwiseRegressionTask, RegressionTask

from .test_classification import ClassificationTestModel


class PixelwiseRegressionTestModel(Module):
def __init__(self, in_channels: int = 3, classes: int = 1, **kwargs: Any) -> None:
super().__init__()
self.conv1 = nn.Conv2d(
in_channels=in_channels, out_channels=classes, kernel_size=1, padding=0
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return cast(torch.Tensor, self.conv1(x))


class RegressionTestModel(ClassificationTestModel):
def __init__(self, in_chans: int = 3, num_classes: int = 1, **kwargs: Any) -> None:
super().__init__(in_chans=in_chans, num_classes=num_classes)
Expand All @@ -48,6 +63,10 @@ def plot(*args: Any, **kwargs: Any) -> None:
raise ValueError


def create_model(**kwargs: Any) -> Module:
return PixelwiseRegressionTestModel(**kwargs)


class TestRegressionTask:
@pytest.mark.parametrize(
"name,classname",
Expand Down Expand Up @@ -103,6 +122,7 @@ def model_kwargs(self) -> dict[str, Any]:
"weights": None,
"num_outputs": 1,
"in_channels": 3,
"loss": "mse",
}

@pytest.fixture(
Expand Down Expand Up @@ -198,3 +218,73 @@ def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None
max_epochs=1,
)
trainer.predict(model=model, datamodule=datamodule)

def test_invalid_loss(self, model_kwargs: dict[str, Any]) -> None:
model_kwargs["loss"] = "invalid_loss"
match = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=match):
RegressionTask(**model_kwargs)


class TestPixelwiseRegressionTask:
@pytest.mark.parametrize(
"name,classname,batch_size,loss",
[
("inria", InriaAerialImageLabelingDataModule, 1, "mse"),
("inria", InriaAerialImageLabelingDataModule, 2, "mae"),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Testing regression on Inria binary [0, 1] masks for now since we don't have a readily available pixelwise regression datamodule.

],
)
def test_trainer(
self,
monkeypatch: MonkeyPatch,
name: str,
classname: type[LightningDataModule],
batch_size: int,
loss: str,
fast_dev_run: bool,
) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(dict[str, dict[str, Any]], conf_dict)

# Instantiate datamodule
datamodule_kwargs = conf_dict["datamodule"]
datamodule_kwargs["batch_size"] = batch_size
datamodule = classname(**datamodule_kwargs)

# Instantiate model
monkeypatch.setattr(smp, "Unet", create_model)
monkeypatch.setattr(smp, "DeepLabV3Plus", create_model)
model_kwargs = conf_dict["module"]
model_kwargs["loss"] = loss
model = PixelwiseRegressionTask(**model_kwargs)

model.model = PixelwiseRegressionTestModel(in_chans=model_kwargs["in_channels"])

# Instantiate trainer
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)

trainer.fit(model=model, datamodule=datamodule)
try:
trainer.test(model=model, datamodule=datamodule)
except MisconfigurationException:
pass
try:
trainer.predict(model=model, datamodule=datamodule)
except MisconfigurationException:
pass

@pytest.fixture
def model_kwargs(self) -> dict[str, Any]:
return {
"model": "resnet18",
"weights": None,
"num_outputs": 1,
"in_channels": 3,
"loss": "mse",
}
3 changes: 2 additions & 1 deletion torchgeo/trainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from .byol import BYOLTask
from .classification import ClassificationTask, MultiLabelClassificationTask
from .detection import ObjectDetectionTask
from .regression import RegressionTask
from .regression import PixelwiseRegressionTask, RegressionTask
from .segmentation import SemanticSegmentationTask

__all__ = (
"BYOLTask",
"ClassificationTask",
"MultiLabelClassificationTask",
"ObjectDetectionTask",
"PixelwiseRegressionTask",
"RegressionTask",
"SemanticSegmentationTask",
)
105 changes: 90 additions & 15 deletions torchgeo/trainers/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@
from typing import Any, cast

import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
import timm
import torch
import torch.nn.functional as F
import torch.nn as nn
from lightning.pytorch import LightningModule
from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection
from torchvision.models._api import WeightsEnum

from ..datasets import unbind_samples
from ..models import get_weight
from ..models import FCN, get_weight
from . import utils


Expand All @@ -35,8 +36,10 @@ class RegressionTask(LightningModule): # type: ignore[misc]
print(timm.list_models())
"""

def config_task(self) -> None:
"""Configures the task based on kwargs parameters."""
target_key: str = "label"
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This allows to not have to duplicate all the train/val/test steps just to change label to mask. Let me know if you have any other suggestions.


def config_model(self) -> None:
"""Configures the model based on kwargs parameters."""
# Create model
weights = self.hyperparams["weights"]
self.model = timm.create_model(
Expand All @@ -56,6 +59,21 @@ def config_task(self) -> None:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.model = utils.load_state_dict(self.model, state_dict)

def config_task(self) -> None:
"""Configures the task based on kwargs parameters."""
self.config_model()

self.loss: nn.Module
if self.hyperparams["loss"] == "mse":
self.loss = nn.MSELoss()
elif self.hyperparams["loss"] == "mae":
self.loss = nn.L1Loss()
else:
raise ValueError(
f"Loss type '{self.hyperparams['loss']}' is not valid. "
f"Currently, supports 'mse' or 'mae' loss."
)

def __init__(self, **kwargs: Any) -> None:
"""Initialize a new LightningModule for training simple regression models.

Expand All @@ -80,7 +98,11 @@ def __init__(self, **kwargs: Any) -> None:
self.config_task()

self.train_metrics = MetricCollection(
{"RMSE": MeanSquaredError(squared=False), "MAE": MeanAbsoluteError()},
{
"RMSE": MeanSquaredError(squared=False),
"MSE": MeanSquaredError(squared=True),
"MAE": MeanAbsoluteError(),
},
prefix="train_",
)
self.val_metrics = self.train_metrics.clone(prefix="val_")
Expand Down Expand Up @@ -108,13 +130,15 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
"""
batch = args[0]
x = batch["image"]
y = batch["label"].view(-1, 1)
y = batch[self.target_key]
y_hat = self(x)

loss = F.mse_loss(y_hat, y)
if y_hat.ndim != y.ndim:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If num_outputs=1 the target variable ground truth is missing the necessary channel dim e.g.

  • (b,) instead of (b, 1)
  • (b, h, w) instead of (b, 1, h, w)

while the output of the models will be:

  • (b, 1)
  • (b, 1, h, w)

y = y.unsqueeze(dim=1)

loss = self.loss(y_hat, y.to(torch.float))
self.log("train_loss", loss) # logging to TensorBoard
self.train_metrics(y_hat, y)
self.train_metrics(y_hat, y.to(torch.float))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cast to float only for loss and metrics in case the plotting expects a different dtype


return loss

Expand All @@ -133,12 +157,15 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None:
batch = args[0]
batch_idx = args[1]
x = batch["image"]
y = batch["label"].view(-1, 1)
y = batch[self.target_key]
y_hat = self(x)

loss = F.mse_loss(y_hat, y)
if y_hat.ndim != y.ndim:
y = y.unsqueeze(dim=1)

loss = self.loss(y_hat, y.to(torch.float))
self.log("val_loss", loss)
self.val_metrics(y_hat, y)
self.val_metrics(y_hat, y.to(torch.float))

if (
batch_idx < 10
Expand All @@ -149,8 +176,11 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None:
):
try:
datamodule = self.trainer.datamodule
if self.target_key == "mask":
y = y.squeeze(dim=1)
y_hat = y_hat.squeeze(dim=1)
batch["prediction"] = y_hat
for key in ["image", "label", "prediction"]:
for key in ["image", self.target_key, "prediction"]:
batch[key] = batch[key].cpu()
sample = unbind_samples(batch)[0]
fig = datamodule.plot(sample)
Expand All @@ -175,12 +205,15 @@ def test_step(self, *args: Any, **kwargs: Any) -> None:
"""
batch = args[0]
x = batch["image"]
y = batch["label"].view(-1, 1)
y = batch[self.target_key]
y_hat = self(x)

loss = F.mse_loss(y_hat, y)
if y_hat.ndim != y.ndim:
y = y.unsqueeze(dim=1)

loss = self.loss(y_hat, y.to(torch.float))
self.log("test_loss", loss)
self.test_metrics(y_hat, y)
self.test_metrics(y_hat, y.to(torch.float))

def on_test_epoch_end(self) -> None:
"""Logs epoch level test metrics."""
Expand Down Expand Up @@ -219,3 +252,45 @@ def configure_optimizers(self) -> dict[str, Any]:
"monitor": "val_loss",
},
}


class PixelwiseRegressionTask(RegressionTask):
"""LightningModule for pixelwise regression of images.

Supports `Segmentation Models Pytorch
<https://github.com/qubvel/segmentation_models.pytorch>`_
as an architecture choice in combination with any of these
`TIMM backbones <https://smp.readthedocs.io/en/latest/encoders_timm.html>`_.

.. versionadded:: 0.5
"""

target_key: str = "mask"

def config_model(self) -> None:
"""Configures the model based on kwargs parameters."""
if self.hyperparams["model"] == "unet":
self.model = smp.Unet(
encoder_name=self.hyperparams["backbone"],
encoder_weights=self.hyperparams["weights"],
in_channels=self.hyperparams["in_channels"],
classes=1,
)
elif self.hyperparams["model"] == "deeplabv3+":
self.model = smp.DeepLabV3Plus(
encoder_name=self.hyperparams["backbone"],
encoder_weights=self.hyperparams["weights"],
in_channels=self.hyperparams["in_channels"],
classes=1,
)
elif self.hyperparams["model"] == "fcn":
self.model = FCN(
in_channels=self.hyperparams["in_channels"],
classes=1,
num_filters=self.hyperparams["num_filters"],
)
else:
raise ValueError(
f"Model type '{self.hyperparams['model']}' is not valid. "
f"Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
)