
Add PixelwiseRegressionTask #1241

Merged · 28 commits into main from trainers/dense-regression · Apr 25, 2023

Changes from 1 commit. Commits (28):
c21b847 · add dense regression task · isaaccorley, Apr 13, 2023
6a2c519 · fix mypy · isaaccorley, Apr 13, 2023
20dd757 · refactor and rename to reduce code duplication · isaaccorley, Apr 21, 2023
0942fd2 · rename · isaaccorley, Apr 21, 2023
39700a8 · add loss to test configs · isaaccorley, Apr 21, 2023
d6fefb6 · rename dense to pixelwise · isaaccorley, Apr 21, 2023
12180b9 · if batch_size=1 then targets will have dims=1 but need dims=2 for loss · isaaccorley, Apr 21, 2023
0b81934 · add tests' · isaaccorley, Apr 21, 2023
ba565e4 · fix mypy errors · isaaccorley, Apr 21, 2023
ffda89e · get full coverage · isaaccorley, Apr 21, 2023
e05a3b5 · test for invalid model · isaaccorley, Apr 21, 2023
eff0917 · add dense regression task · isaaccorley, Apr 13, 2023
298d7bb · fix mypy · isaaccorley, Apr 13, 2023
7cf0e26 · refactor and rename to reduce code duplication · isaaccorley, Apr 21, 2023
381a9d9 · rename · isaaccorley, Apr 21, 2023
a36d74d · add loss to test configs · isaaccorley, Apr 21, 2023
2ae4d8b · rename dense to pixelwise · isaaccorley, Apr 21, 2023
e6fbbdb · if batch_size=1 then targets will have dims=1 but need dims=2 for loss · isaaccorley, Apr 21, 2023
8741d6b · add tests' · isaaccorley, Apr 21, 2023
74a14d6 · fix mypy errors · isaaccorley, Apr 21, 2023
e8b001f · get full coverage · isaaccorley, Apr 21, 2023
d71a252 · test for invalid model · isaaccorley, Apr 21, 2023
fb5c642 · Merge branch 'trainers/dense-regression' of github.com:isaaccorley/to… · isaaccorley, Apr 22, 2023
e9ca047 · Merge branch 'main' into trainers/dense-regression · isaaccorley, Apr 24, 2023
e57eb49 · Merge branch 'trainers/dense-regression' of github.com:isaaccorley/to… · isaaccorley, Apr 24, 2023
5bc1f5d · update tests · isaaccorley, Apr 24, 2023
cdcd3c5 · Merge branch 'main' into trainers/dense-regression · isaaccorley, Apr 24, 2023
b135bdc · 100% coverage · isaaccorley, Apr 24, 2023
Viewing commit 8741d6bae2bf5d7903a5c767bcd1bb7801a61aa0: add tests'
isaaccorley committed Apr 22, 2023
tests/trainers/test_regression.py · 85 changes: 84 additions & 1 deletion
@@ -6,29 +6,44 @@
 from typing import Any, cast
 
 import pytest
+import segmentation_models_pytorch as smp
 import timm
 import torch
+import torch.nn as nn
 import torchvision
 from _pytest.fixtures import SubRequest
 from _pytest.monkeypatch import MonkeyPatch
 from lightning.pytorch import LightningDataModule, Trainer
 from omegaconf import OmegaConf
+from torch.nn.modules import Module
 from torchvision.models._api import WeightsEnum
 
 from torchgeo.datamodules import (
     COWCCountingDataModule,
+    InriaAerialImageLabelingDataModule,
     MisconfigurationException,
     SKIPPDDataModule,
     SustainBenchCropYieldDataModule,
     TropicalCycloneDataModule,
 )
 from torchgeo.datasets import TropicalCyclone
 from torchgeo.models import get_model_weights, list_models
-from torchgeo.trainers import RegressionTask
+from torchgeo.trainers import PixelwiseRegressionTask, RegressionTask
 
 from .test_classification import ClassificationTestModel
 
 
+class PixelwiseRegressionTestModel(Module):
+    def __init__(self, in_channels: int = 3, classes: int = 1, **kwargs: Any) -> None:
+        super().__init__()
+        self.conv1 = nn.Conv2d(
+            in_channels=in_channels, out_channels=classes, kernel_size=1, padding=0
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return cast(torch.Tensor, self.conv1(x))
+
+
 class RegressionTestModel(ClassificationTestModel):
     def __init__(self, in_chans: int = 3, num_classes: int = 1, **kwargs: Any) -> None:
         super().__init__(in_chans=in_chans, num_classes=num_classes)
@@ -48,6 +63,10 @@ def plot(*args: Any, **kwargs: Any) -> None:
     raise ValueError
 
 
+def create_model(**kwargs: Any) -> Module:
+    return PixelwiseRegressionTestModel(**kwargs)
+
+
 class TestRegressionTask:
     @pytest.mark.parametrize(
         "name,classname",
@@ -205,3 +224,67 @@ def test_invalid_loss(self, model_kwargs: dict[str, Any]) -> None:
         match = "Loss type 'invalid_loss' is not valid."
         with pytest.raises(ValueError, match=match):
             RegressionTask(**model_kwargs)


+class TestPixelwiseRegressionTask:
+    @pytest.mark.parametrize(
+        "name,classname,batch_size,loss",
+        [
+            ("inria", InriaAerialImageLabelingDataModule, 1, "mse"),
+            ("inria", InriaAerialImageLabelingDataModule, 2, "mae"),
+        ],
+    )

isaaccorley (Collaborator, Author) commented on the inria parametrize entries:

Testing regression on Inria binary [0, 1] masks for now, since we don't have a readily available pixelwise regression datamodule.
+    def test_trainer(
+        self,
+        monkeypatch: MonkeyPatch,
+        name: str,
+        classname: type[LightningDataModule],
+        batch_size: int,
+        loss: str,
+        fast_dev_run: bool,
+    ) -> None:
+        conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
+        conf_dict = OmegaConf.to_object(conf.experiment)
+        conf_dict = cast(dict[str, dict[str, Any]], conf_dict)
+
+        # Instantiate datamodule
+        datamodule_kwargs = conf_dict["datamodule"]
+        datamodule_kwargs["batch_size"] = batch_size
+        datamodule = classname(**datamodule_kwargs)
+
+        # Instantiate model
+        monkeypatch.setattr(smp, "Unet", create_model)
+        monkeypatch.setattr(smp, "DeepLabV3Plus", create_model)
+        model_kwargs = conf_dict["module"]
+        model_kwargs["loss"] = loss
+        model = PixelwiseRegressionTask(**model_kwargs)
+
+        model.model = PixelwiseRegressionTestModel(in_chans=model_kwargs["in_channels"])
+
+        # Instantiate trainer
+        trainer = Trainer(
+            accelerator="cpu",
+            fast_dev_run=fast_dev_run,
+            log_every_n_steps=1,
+            max_epochs=1,
+        )
+
+        trainer.fit(model=model, datamodule=datamodule)
+        try:
+            trainer.test(model=model, datamodule=datamodule)
+        except MisconfigurationException:
+            pass
+        try:
+            trainer.predict(model=model, datamodule=datamodule)
+        except MisconfigurationException:
+            pass
+
+    @pytest.fixture
+    def model_kwargs(self) -> dict[str, Any]:
+        return {
+            "model": "resnet18",
+            "weights": None,
+            "num_outputs": 1,
+            "in_channels": 3,
+            "loss": "mse",
+        }
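For reference, a minimal sketch of constructing the new task with these fixture arguments; this assumes PixelwiseRegressionTask accepts the same keyword arguments as RegressionTask, which is what the fixture suggests:

from torchgeo.trainers import PixelwiseRegressionTask

# Arguments mirror the model_kwargs fixture above; whether "resnet18" is a
# sensible choice for a real pixelwise model is outside the tests' scope.
task = PixelwiseRegressionTask(
    model="resnet18",
    weights=None,
    num_outputs=1,
    in_channels=3,
    loss="mse",
)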
torchgeo/trainers/regression.py · 31 changes: 15 additions & 16 deletions
@@ -131,16 +131,14 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
         batch = args[0]
         x = batch["image"]
         y = batch[self.target_key]
-
-        if y.ndim == 1:
-            y = y.unsqueeze(dim=1)
-
         y_hat = self(x)
 
-        loss = self.loss(y_hat, y)
+        if y_hat.ndim != y.ndim:
+            y = y.unsqueeze(dim=1)
+
+        loss = self.loss(y_hat, y.to(torch.float))
         self.log("train_loss", loss)  # logging to TensorBoard
-        self.train_metrics(y_hat, y)
+        self.train_metrics(y_hat, y.to(torch.float))
 
         return loss

isaaccorley (Collaborator, Author) commented on the "if y_hat.ndim != y.ndim:" line:

If num_outputs=1, the target ground truth is missing the necessary channel dim, e.g.

  • (b,) instead of (b, 1)
  • (b, h, w) instead of (b, 1, h, w)

while the output of the models will be:

  • (b, 1)
  • (b, 1, h, w)
isaaccorley (Collaborator, Author) commented on the "self.train_metrics(y_hat, y.to(torch.float))" line:

Cast to float only for loss and metrics, in case the plotting expects a different dtype.
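Note that Tensor.to returns a new tensor rather than casting in place, so the integer target survives unchanged for plotting; a minimal illustration:

import torch

y = torch.randint(0, 2, (2, 1, 8, 8))  # integer mask target, e.g. Inria [0, 1]
y_float = y.to(torch.float)            # new float tensor, used for loss/metrics
assert y.dtype == torch.int64          # original dtype intact for plotting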

@@ -160,15 +158,14 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None:
         batch_idx = args[1]
         x = batch["image"]
         y = batch[self.target_key]
+        y_hat = self(x)
 
-        if y.ndim == 1:
+        if y_hat.ndim != y.ndim:
             y = y.unsqueeze(dim=1)
 
-        y_hat = self(x)
-
-        loss = self.loss(y_hat, y)
+        loss = self.loss(y_hat, y.to(torch.float))
         self.log("val_loss", loss)
-        self.val_metrics(y_hat, y)
+        self.val_metrics(y_hat, y.to(torch.float))
 
         if (
             batch_idx < 10
@@ -179,6 +176,9 @@
         ):
             try:
                 datamodule = self.trainer.datamodule
+                if self.target_key == "mask":
+                    y = y.squeeze(dim=1)
+                    y_hat = y_hat.squeeze(dim=1)
                 batch["prediction"] = y_hat
                 for key in ["image", self.target_key, "prediction"]:
                     batch[key] = batch[key].cpu()
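The squeeze added here mirrors the earlier unsqueeze: the channel dimension required by the loss is dropped again before handing tensors to the datamodule's plotting, which presumably expects (h, w) masks. A tiny illustration (shapes assumed):

import torch

y_hat = torch.randn(4, 1, 32, 32)  # pixelwise prediction with channel dim
y_hat = y_hat.squeeze(dim=1)       # back to (4, 32, 32) before plotting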
@@ -206,15 +206,14 @@ def test_step(self, *args: Any, **kwargs: Any) -> None:
         batch = args[0]
         x = batch["image"]
         y = batch[self.target_key]
+        y_hat = self(x)
 
-        if y.ndim == 1:
+        if y_hat.ndim != y.ndim:
             y = y.unsqueeze(dim=1)
 
-        y_hat = self(x)
-
-        loss = self.loss(y_hat, y)
+        loss = self.loss(y_hat, y.to(torch.float))
         self.log("test_loss", loss)
-        self.test_metrics(y_hat, y)
+        self.test_metrics(y_hat, y.to(torch.float))

     def on_test_epoch_end(self) -> None:
         """Logs epoch level test metrics."""