Add PixelwiseRegressionTask #1241
```diff
@@ -7,17 +7,18 @@
 from typing import Any, cast

 import matplotlib.pyplot as plt
+import segmentation_models_pytorch as smp
 import timm
 import torch
-import torch.nn.functional as F
+import torch.nn as nn
 from lightning.pytorch import LightningModule
 from torch import Tensor
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection
 from torchvision.models._api import WeightsEnum

 from ..datasets import unbind_samples
-from ..models import get_weight
+from ..models import FCN, get_weight
 from . import utils
```
```diff
@@ -35,8 +36,10 @@ class RegressionTask(LightningModule):  # type: ignore[misc]
         print(timm.list_models())
     """

-    def config_task(self) -> None:
-        """Configures the task based on kwargs parameters."""
+    target_key: str = "label"
+
+    def config_model(self) -> None:
+        """Configures the model based on kwargs parameters."""
         # Create model
         weights = self.hyperparams["weights"]
         self.model = timm.create_model(
```

Reviewer comment (on `target_key`): This allows us to avoid duplicating all the train/val/test steps just to change which key of the batch the target is read from.
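To illustrate the reviewer's point, here is a minimal, self-contained sketch of the pattern (hypothetical names, not torchgeo code): a class-level key lets a subclass retarget the shared step logic without overriding any methods.

```python
# Hypothetical standalone sketch: a class attribute decides which batch entry
# the shared step logic reads, so subclasses only override the attribute.
class BaseTask:
    target_key: str = "label"

    def step(self, batch: dict) -> None:
        y = batch[self.target_key]  # shared code, configurable target
        print(y)


class DenseTask(BaseTask):
    target_key = "mask"  # the only change needed for pixelwise targets


DenseTask().step({"image": 0, "mask": 1})  # prints 1
```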
```diff
@@ -56,6 +59,21 @@ def config_task(self) -> None:
             state_dict = get_weight(weights).get_state_dict(progress=True)
             self.model = utils.load_state_dict(self.model, state_dict)

+    def config_task(self) -> None:
+        """Configures the task based on kwargs parameters."""
+        self.config_model()
+
+        self.loss: nn.Module
+        if self.hyperparams["loss"] == "mse":
+            self.loss = nn.MSELoss()
+        elif self.hyperparams["loss"] == "mae":
+            self.loss = nn.L1Loss()
+        else:
+            raise ValueError(
+                f"Loss type '{self.hyperparams['loss']}' is not valid. "
+                f"Currently, supports 'mse' or 'mae' loss."
+            )
+
     def __init__(self, **kwargs: Any) -> None:
         """Initialize a new LightningModule for training simple regression models.
```
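The loss selection above is a simple dispatch; here is the same logic in isolation (the plain dict stands in for the task's `self.hyperparams`):

```python
import torch
import torch.nn as nn

# Stand-alone version of the dispatch in config_task; the dict replaces
# self.hyperparams for illustration.
def build_loss(hyperparams: dict) -> nn.Module:
    if hyperparams["loss"] == "mse":
        return nn.MSELoss()
    elif hyperparams["loss"] == "mae":
        return nn.L1Loss()
    raise ValueError(f"Loss type '{hyperparams['loss']}' is not valid.")

loss = build_loss({"loss": "mae"})
print(loss(torch.zeros(2, 1), torch.ones(2, 1)))  # tensor(1.)
```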
```diff
@@ -80,7 +98,11 @@ def __init__(self, **kwargs: Any) -> None:
         self.config_task()

         self.train_metrics = MetricCollection(
-            {"RMSE": MeanSquaredError(squared=False), "MAE": MeanAbsoluteError()},
+            {
+                "RMSE": MeanSquaredError(squared=False),
+                "MSE": MeanSquaredError(squared=True),
+                "MAE": MeanAbsoluteError(),
+            },
             prefix="train_",
         )
         self.val_metrics = self.train_metrics.clone(prefix="val_")
```
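For reference, the cloned collections behave like this (a quick check, not part of the diff; `clone(prefix=...)` reuses the same metric definitions under a new logging prefix):

```python
import torch
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection

train_metrics = MetricCollection(
    {
        "RMSE": MeanSquaredError(squared=False),
        "MSE": MeanSquaredError(squared=True),
        "MAE": MeanAbsoluteError(),
    },
    prefix="train_",
)
val_metrics = train_metrics.clone(prefix="val_")
print(val_metrics(torch.tensor([0.0, 2.0]), torch.tensor([1.0, 1.0])))
# errors are -1 and +1, so RMSE, MSE, and MAE all evaluate to 1.0 here,
# logged as val_RMSE, val_MSE, and val_MAE
```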
```diff
@@ -108,13 +130,15 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
         """
         batch = args[0]
         x = batch["image"]
-        y = batch["label"].view(-1, 1)
+        y = batch[self.target_key]
         y_hat = self(x)

-        loss = F.mse_loss(y_hat, y)
+        if y_hat.ndim != y.ndim:
+            y = y.unsqueeze(dim=1)
+
+        loss = self.loss(y_hat, y.to(torch.float))
         self.log("train_loss", loss)  # logging to TensorBoard
-        self.train_metrics(y_hat, y)
+        self.train_metrics(y_hat, y.to(torch.float))

         return loss
```

Reviewer comment (on the `ndim` check): If `target_key` is `"mask"`, `y` will have shape `(batch_size, height, width)` while the output of the models will be `(batch_size, 1, height, width)`; the `unsqueeze` aligns them.

Reviewer comment (on the casts): Cast to float only for the loss and metrics, in case the plotting expects a different dtype.
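A quick shape check of the `unsqueeze` alignment described in the comment above (illustrative only):

```python
import torch

y = torch.rand(8, 256, 256)         # mask target: (batch, height, width)
y_hat = torch.rand(8, 1, 256, 256)  # model output: (batch, 1, height, width)
if y_hat.ndim != y.ndim:
    y = y.unsqueeze(dim=1)          # align shapes for the elementwise loss
assert y.shape == y_hat.shape
```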
```diff
@@ -133,12 +157,15 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None:
         batch = args[0]
         batch_idx = args[1]
         x = batch["image"]
-        y = batch["label"].view(-1, 1)
+        y = batch[self.target_key]
         y_hat = self(x)

-        loss = F.mse_loss(y_hat, y)
+        if y_hat.ndim != y.ndim:
+            y = y.unsqueeze(dim=1)
+
+        loss = self.loss(y_hat, y.to(torch.float))
         self.log("val_loss", loss)
-        self.val_metrics(y_hat, y)
+        self.val_metrics(y_hat, y.to(torch.float))

         if (
             batch_idx < 10
@@ -149,8 +176,11 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None:
         ):
             try:
                 datamodule = self.trainer.datamodule
+                if self.target_key == "mask":
+                    y = y.squeeze(dim=1)
+                    y_hat = y_hat.squeeze(dim=1)
                 batch["prediction"] = y_hat
-                for key in ["image", "label", "prediction"]:
+                for key in ["image", self.target_key, "prediction"]:
                     batch[key] = batch[key].cpu()
                 sample = unbind_samples(batch)[0]
                 fig = datamodule.plot(sample)
```
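The plotting path relies on `unbind_samples`, which conceptually splits a batched sample dict into per-sample dicts. A simplified stand-in (not the torchgeo implementation) to show the idea:

```python
import torch

def unbind_samples_sketch(batch: dict) -> list[dict]:
    # Split each batched tensor along dim 0 and regroup values per sample.
    keys = list(batch.keys())
    unbound = [torch.unbind(batch[key]) for key in keys]
    return [dict(zip(keys, values)) for values in zip(*unbound)]

batch = {"image": torch.rand(4, 3, 8, 8), "mask": torch.rand(4, 8, 8)}
sample = unbind_samples_sketch(batch)[0]
print(sample["image"].shape, sample["mask"].shape)  # (3, 8, 8) and (8, 8)
```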
```diff
@@ -175,12 +205,15 @@ def test_step(self, *args: Any, **kwargs: Any) -> None:
         """
         batch = args[0]
         x = batch["image"]
-        y = batch["label"].view(-1, 1)
+        y = batch[self.target_key]
         y_hat = self(x)

-        loss = F.mse_loss(y_hat, y)
+        if y_hat.ndim != y.ndim:
+            y = y.unsqueeze(dim=1)
+
+        loss = self.loss(y_hat, y.to(torch.float))
         self.log("test_loss", loss)
-        self.test_metrics(y_hat, y)
+        self.test_metrics(y_hat, y.to(torch.float))

     def on_test_epoch_end(self) -> None:
         """Logs epoch level test metrics."""
```
```diff
@@ -219,3 +252,45 @@ def configure_optimizers(self) -> dict[str, Any]:
                 "monitor": "val_loss",
             },
         }
+
+
+class PixelwiseRegressionTask(RegressionTask):
+    """LightningModule for pixelwise regression of images.
+
+    Supports `Segmentation Models Pytorch
+    <https://github.com/qubvel/segmentation_models.pytorch>`_
+    as an architecture choice in combination with any of these
+    `TIMM backbones <https://smp.readthedocs.io/en/latest/encoders_timm.html>`_.
+
+    .. versionadded:: 0.5
+    """
+
+    target_key: str = "mask"
+
+    def config_model(self) -> None:
+        """Configures the model based on kwargs parameters."""
+        if self.hyperparams["model"] == "unet":
+            self.model = smp.Unet(
+                encoder_name=self.hyperparams["backbone"],
+                encoder_weights=self.hyperparams["weights"],
+                in_channels=self.hyperparams["in_channels"],
+                classes=1,
+            )
+        elif self.hyperparams["model"] == "deeplabv3+":
+            self.model = smp.DeepLabV3Plus(
+                encoder_name=self.hyperparams["backbone"],
+                encoder_weights=self.hyperparams["weights"],
+                in_channels=self.hyperparams["in_channels"],
+                classes=1,
+            )
+        elif self.hyperparams["model"] == "fcn":
+            self.model = FCN(
+                in_channels=self.hyperparams["in_channels"],
+                classes=1,
+                num_filters=self.hyperparams["num_filters"],
+            )
+        else:
+            raise ValueError(
+                f"Model type '{self.hyperparams['model']}' is not valid. "
+                f"Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
+            )
```
Reviewer comment: Testing regression on Inria binary [0, 1] masks for now, since we don't have a readily available pixelwise regression datamodule.
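A hedged usage sketch based on the hyperparameters read in `config_model` and `config_task` above (the import path and exact keyword handling are assumptions about the final API, not confirmed by this diff):

```python
from torchgeo.trainers import PixelwiseRegressionTask  # assumed import path

# Keyword names mirror the hyperparams keys read in the diff above.
task = PixelwiseRegressionTask(
    model="unet",         # or "deeplabv3+" / "fcn"
    backbone="resnet18",  # any smp-supported (TIMM) encoder
    weights="imagenet",   # forwarded to smp as encoder_weights
    in_channels=3,
    loss="mse",           # or "mae"
)
```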