Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Segmentation Pretrained Weights #1046

2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ datamodule = InriaAerialImageLabelingDataModule(root="...", batch_size=64, num_w
task = SemanticSegmentationTask(
model="unet",
backbone="resnet50",
weights="imagenet",
weights=True,
in_channels=3,
num_classes=2,
loss="ce",
Expand Down
2 changes: 1 addition & 1 deletion conf/etci2021.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: "imagenet"
weights: true
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 6
Expand Down
2 changes: 1 addition & 1 deletion conf/inria.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: "imagenet"
weights: true
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
Expand Down
2 changes: 1 addition & 1 deletion conf/landcoverai.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: "imagenet"
weights: true
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
Expand Down
2 changes: 1 addition & 1 deletion conf/naipchesapeake.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module:
loss: "ce"
model: "deeplabv3+"
backbone: "resnet34"
weights: "imagenet"
weights: true
learning_rate: 1e-3
learning_rate_schedule_patience: 2
in_channels: 4
Expand Down
2 changes: 1 addition & 1 deletion conf/spacenet1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: "imagenet"
weights: true
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/inria.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: "imagenet"
weights: true
isaaccorley marked this conversation as resolved.
Show resolved Hide resolved
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
Expand Down
23 changes: 19 additions & 4 deletions torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Segmentation tasks."""

import os
import warnings
from typing import Any, cast

Expand All @@ -15,9 +16,11 @@
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
from torchvision.models._api import WeightsEnum

from ..datasets.utils import unbind_samples
from ..models import FCN
from ..models import FCN, get_weight
from . import utils


class SemanticSegmentationTask(LightningModule): # type: ignore[misc]
Expand All @@ -31,17 +34,19 @@ class SemanticSegmentationTask(LightningModule): # type: ignore[misc]

def config_task(self) -> None:
"""Configures the task based on kwargs parameters passed to the constructor."""
weights = self.hyperparams["weights"]

if self.hyperparams["model"] == "unet":
self.model = smp.Unet(
encoder_name=self.hyperparams["backbone"],
encoder_weights=self.hyperparams["weights"],
encoder_weights="imagenet" if weights is True else None,
in_channels=self.hyperparams["in_channels"],
classes=self.hyperparams["num_classes"],
)
elif self.hyperparams["model"] == "deeplabv3+":
self.model = smp.DeepLabV3Plus(
encoder_name=self.hyperparams["backbone"],
encoder_weights=self.hyperparams["weights"],
encoder_weights="imagenet" if weights is True else None,
in_channels=self.hyperparams["in_channels"],
classes=self.hyperparams["num_classes"],
)
Expand Down Expand Up @@ -80,13 +85,23 @@ def config_task(self) -> None:
f"Currently, supports 'ce', 'jaccard' or 'focal' loss."
)

if self.hyperparams["model"] != "fcn":
if weights and weights is not True:
if isinstance(weights, WeightsEnum):
state_dict = weights.get_state_dict(progress=True)
elif os.path.exists(weights):
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.model.encoder = utils.load_state_dict(self.model, state_dict)

def __init__(self, **kwargs: Any) -> None:
"""Initialize the LightningModule with a model and loss function.

Keyword Args:
model: Name of the segmentation model type to use
backbone: Name of the timm backbone to use
weights: None or "imagenet" to use imagenet pretrained weights in
weights: None or True to use imagenet pretrained weights in
isaaccorley marked this conversation as resolved.
Show resolved Hide resolved
the backbone
in_channels: Number of channels in input image
num_classes: Number of semantic classes to predict
Expand Down