From fb488f4fb17a9cced8e32921a1d9bb2ae1bdd26e Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Wed, 25 Jan 2023 23:32:51 +0000 Subject: [PATCH 1/8] add pretrained weights loading for the segmentation encoder --- torchgeo/trainers/segmentation.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index fed2c596cc9..a6b841ac32e 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -3,6 +3,7 @@ """Segmentation tasks.""" +import os import warnings from typing import Any, cast @@ -15,9 +16,11 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau from torchmetrics import MetricCollection from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex +from torchvision.models._api import WeightsEnum from ..datasets.utils import unbind_samples -from ..models import FCN +from ..models import FCN, get_weight +from . import utils class SemanticSegmentationTask(LightningModule): # type: ignore[misc] @@ -31,17 +34,19 @@ class SemanticSegmentationTask(LightningModule): # type: ignore[misc] def config_task(self) -> None: """Configures the task based on kwargs parameters passed to the constructor.""" + weights = self.hyperparams["weights"] + if self.hyperparams["model"] == "unet": self.model = smp.Unet( encoder_name=self.hyperparams["backbone"], - encoder_weights=self.hyperparams["weights"], + encoder_weights="imagenet" if weights is True else None, in_channels=self.hyperparams["in_channels"], classes=self.hyperparams["num_classes"], ) elif self.hyperparams["model"] == "deeplabv3+": self.model = smp.DeepLabV3Plus( encoder_name=self.hyperparams["backbone"], - encoder_weights=self.hyperparams["weights"], + encoder_weights="imagenet" if weights is True else None, in_channels=self.hyperparams["in_channels"], classes=self.hyperparams["num_classes"], ) @@ -80,6 +85,16 @@ def config_task(self) -> None: f"Currently, supports 'ce', 'jaccard' or 'focal' loss." ) + if self.hyperparams["model"] != "fcn": + if weights and weights is not True: + if isinstance(weights, WeightsEnum): + state_dict = weights.get_state_dict(progress=True) + elif os.path.exists(weights): + _, state_dict = utils.extract_backbone(weights) + else: + state_dict = get_weight(weights).get_state_dict(progress=True) + self.model.encoder = utils.load_state_dict(self.model, state_dict) + def __init__(self, **kwargs: Any) -> None: """Initialize the LightningModule with a model and loss function. From 29e89ffd532c1f05c5c8de702c481ec83b878f5b Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Tue, 25 Apr 2023 22:05:32 -0700 Subject: [PATCH 2/8] Updating config files to use new pretrained arg style --- README.md | 2 +- conf/etci2021.yaml | 2 +- conf/inria.yaml | 2 +- conf/landcoverai.yaml | 2 +- conf/naipchesapeake.yaml | 2 +- conf/spacenet1.yaml | 2 +- tests/conf/inria.yaml | 2 +- torchgeo/trainers/segmentation.py | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 13711cde9c5..8a5e59d114e 100644 --- a/README.md +++ b/README.md @@ -128,7 +128,7 @@ datamodule = InriaAerialImageLabelingDataModule(root="...", batch_size=64, num_w task = SemanticSegmentationTask( model="unet", backbone="resnet50", - weights="imagenet", + weights=True, in_channels=3, num_classes=2, loss="ce", diff --git a/conf/etci2021.yaml b/conf/etci2021.yaml index 5f06393cfd5..e993b8ac628 100644 --- a/conf/etci2021.yaml +++ b/conf/etci2021.yaml @@ -3,7 +3,7 @@ module: loss: "ce" model: "unet" backbone: "resnet18" - weights: "imagenet" + weights: true learning_rate: 1e-3 learning_rate_schedule_patience: 6 in_channels: 6 diff --git a/conf/inria.yaml b/conf/inria.yaml index 32083a3c6b9..bbf73669a6a 100644 --- a/conf/inria.yaml +++ b/conf/inria.yaml @@ -3,7 +3,7 @@ module: loss: "ce" model: "unet" backbone: "resnet18" - weights: "imagenet" + weights: true learning_rate: 1e-3 learning_rate_schedule_patience: 6 in_channels: 3 diff --git a/conf/landcoverai.yaml b/conf/landcoverai.yaml index c3742581972..f70667fe056 100644 --- a/conf/landcoverai.yaml +++ b/conf/landcoverai.yaml @@ -3,7 +3,7 @@ module: loss: "ce" model: "unet" backbone: "resnet18" - weights: "imagenet" + weights: true learning_rate: 1e-3 learning_rate_schedule_patience: 6 in_channels: 3 diff --git a/conf/naipchesapeake.yaml b/conf/naipchesapeake.yaml index 6b562778fe0..94f6cafcab6 100644 --- a/conf/naipchesapeake.yaml +++ b/conf/naipchesapeake.yaml @@ -3,7 +3,7 @@ module: loss: "ce" model: "deeplabv3+" backbone: "resnet34" - weights: "imagenet" + weights: true learning_rate: 1e-3 learning_rate_schedule_patience: 2 in_channels: 4 diff --git a/conf/spacenet1.yaml b/conf/spacenet1.yaml index c7a236f4634..82955319a57 100644 --- a/conf/spacenet1.yaml +++ b/conf/spacenet1.yaml @@ -3,7 +3,7 @@ module: loss: "ce" model: "unet" backbone: "resnet18" - weights: "imagenet" + weights: true learning_rate: 1e-3 learning_rate_schedule_patience: 6 in_channels: 3 diff --git a/tests/conf/inria.yaml b/tests/conf/inria.yaml index c6296802f9e..7a47124bee5 100644 --- a/tests/conf/inria.yaml +++ b/tests/conf/inria.yaml @@ -3,7 +3,7 @@ module: loss: "ce" model: "unet" backbone: "resnet18" - weights: "imagenet" + weights: true learning_rate: 1e-3 learning_rate_schedule_patience: 6 in_channels: 3 diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index a6b841ac32e..3b989d73919 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -101,7 +101,7 @@ def __init__(self, **kwargs: Any) -> None: Keyword Args: model: Name of the segmentation model type to use backbone: Name of the timm backbone to use - weights: None or "imagenet" to use imagenet pretrained weights in + weights: None or True to use imagenet pretrained weights in the backbone in_channels: Number of channels in input image num_classes: Number of semantic classes to predict From bcc64f3e2a88df3c0cea2c4e444dd03ff9f15b9d Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Wed, 3 May 2023 17:16:00 +0000 Subject: [PATCH 3/8] fix loading weights enum to encoder --- torchgeo/trainers/segmentation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 61e5d3276b9..48b754a7400 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -89,11 +89,13 @@ def config_task(self) -> None: if weights and weights is not True: if isinstance(weights, WeightsEnum): state_dict = weights.get_state_dict(progress=True) + self.model.encoder.load_state_dict(state_dict) elif os.path.exists(weights): _, state_dict = utils.extract_backbone(weights) + self.model.encoder = utils.load_state_dict(self.model, state_dict) else: state_dict = get_weight(weights).get_state_dict(progress=True) - self.model.encoder = utils.load_state_dict(self.model, state_dict) + self.model.encoder = utils.load_state_dict(self.model, state_dict) # Freeze backbone if self.hyperparams.get("freeze_backbone", False) and self.hyperparams[ From 02bcb657d9d3b17a3417cf4bae27292f4af4d89c Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Wed, 3 May 2023 18:05:42 +0000 Subject: [PATCH 4/8] I have no idea what I'm doing with these tests --- tests/trainers/test_segmentation.py | 61 +++++++++++++++++++++++++++++ torchgeo/trainers/segmentation.py | 8 +++- 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index e2ac64fef84..5c4850dacd6 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -2,20 +2,26 @@ # Licensed under the MIT License. import os +from pathlib import Path from typing import Any, cast import pytest import segmentation_models_pytorch as smp +import timm import torch import torch.nn as nn +import torchvision +from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch from hydra.utils import instantiate from lightning.pytorch import Trainer from omegaconf import OmegaConf from torch.nn.modules import Module +from torchvision.models._api import WeightsEnum from torchgeo.datamodules import MisconfigurationException, SEN12MSDataModule from torchgeo.datasets import LandCoverAI +from torchgeo.models import get_model_weights, list_models from torchgeo.trainers import SemanticSegmentationTask @@ -34,6 +40,11 @@ def create_model(**kwargs: Any) -> Module: return SegmentationTestModel(**kwargs) +def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: + state_dict: dict[str, Any] = torch.load(url) + return state_dict + + def plot(*args: Any, **kwargs: Any) -> None: raise ValueError @@ -111,6 +122,56 @@ def model_kwargs(self) -> dict[Any, Any]: "ignore_index": 0, } + @pytest.fixture( + params=[ + weights + for model in list_models() + for weights in get_model_weights(model) + if "resnet" in weights.meta["model"] + ] + ) + def weights(self, request: SubRequest) -> WeightsEnum: + return request.param + + @pytest.fixture + def mocked_weights( + self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + ) -> WeightsEnum: + path = tmp_path / f"{weights}.pth" + model = timm.create_model( + weights.meta["model"], in_chans=weights.meta["in_chans"] + ) + torch.save(model.state_dict(), path) + try: + monkeypatch.setattr(weights.value, "url", str(path)) + except AttributeError: + monkeypatch.setattr(weights, "url", str(path)) + monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + return weights + + def test_weight_file(self, model_kwargs: dict[str, Any], checkpoint: str) -> None: + model_kwargs["weights"] = checkpoint + with pytest.warns(UserWarning): + SemanticSegmentationTask(**model_kwargs) + + def test_weight_enum( + self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum + ) -> None: + model_kwargs["backbone"] = mocked_weights.meta["model"] + model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] + model_kwargs["weights"] = mocked_weights + with pytest.warns(UserWarning): + SemanticSegmentationTask(**model_kwargs) + + def test_weight_str( + self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum + ) -> None: + model_kwargs["backbone"] = mocked_weights.meta["model"] + model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] + model_kwargs["weights"] = str(mocked_weights) + with pytest.warns(UserWarning): + SemanticSegmentationTask(**model_kwargs) + def test_invalid_model(self, model_kwargs: dict[Any, Any]) -> None: model_kwargs["model"] = "invalid_model" match = "Model type 'invalid_model' is not valid." diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 48b754a7400..af1d0097e84 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -92,10 +92,14 @@ def config_task(self) -> None: self.model.encoder.load_state_dict(state_dict) elif os.path.exists(weights): _, state_dict = utils.extract_backbone(weights) - self.model.encoder = utils.load_state_dict(self.model, state_dict) + self.model.encoder = utils.load_state_dict( + self.model.encoder, state_dict + ) else: state_dict = get_weight(weights).get_state_dict(progress=True) - self.model.encoder = utils.load_state_dict(self.model, state_dict) + self.model.encoder = utils.load_state_dict( + self.model.encoder, state_dict + ) # Freeze backbone if self.hyperparams.get("freeze_backbone", False) and self.hyperparams[ From d8fc8af422f1770bc5f4351421fd4af70d0c8928 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Wed, 3 May 2023 19:17:44 +0000 Subject: [PATCH 5/8] tests passing --- tests/trainers/test_segmentation.py | 6 ------ torchgeo/trainers/segmentation.py | 9 ++++----- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 5c4850dacd6..24babfa32f7 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -151,8 +151,6 @@ def mocked_weights( def test_weight_file(self, model_kwargs: dict[str, Any], checkpoint: str) -> None: model_kwargs["weights"] = checkpoint - with pytest.warns(UserWarning): - SemanticSegmentationTask(**model_kwargs) def test_weight_enum( self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum @@ -160,8 +158,6 @@ def test_weight_enum( model_kwargs["backbone"] = mocked_weights.meta["model"] model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] model_kwargs["weights"] = mocked_weights - with pytest.warns(UserWarning): - SemanticSegmentationTask(**model_kwargs) def test_weight_str( self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum @@ -169,8 +165,6 @@ def test_weight_str( model_kwargs["backbone"] = mocked_weights.meta["model"] model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] model_kwargs["weights"] = str(mocked_weights) - with pytest.warns(UserWarning): - SemanticSegmentationTask(**model_kwargs) def test_invalid_model(self, model_kwargs: dict[Any, Any]) -> None: model_kwargs["model"] = "invalid_model" diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index af1d0097e84..a01719d8ccb 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -97,9 +97,7 @@ def config_task(self) -> None: ) else: state_dict = get_weight(weights).get_state_dict(progress=True) - self.model.encoder = utils.load_state_dict( - self.model.encoder, state_dict - ) + self.model.encoder.load_state_dict(state_dict) # Freeze backbone if self.hyperparams.get("freeze_backbone", False) and self.hyperparams[ @@ -121,8 +119,9 @@ def __init__(self, **kwargs: Any) -> None: Keyword Args: model: Name of the segmentation model type to use backbone: Name of the timm backbone to use - weights: None or True to use imagenet pretrained weights in - the backbone + weights: Either a weight enum, the string representation of a weight enum, + True for ImageNet weights, False or None for random weights, + or the path to a saved model state dict. in_channels: Number of channels in input image num_classes: Number of semantic classes to predict loss: Name of the loss function, currently supports From 3a576bb371d53b483910932f3365235482a2bc3c Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Wed, 3 May 2023 19:28:24 +0000 Subject: [PATCH 6/8] update docstring --- torchgeo/trainers/segmentation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index a01719d8ccb..9ba773a49a3 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -121,7 +121,8 @@ def __init__(self, **kwargs: Any) -> None: backbone: Name of the timm backbone to use weights: Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, - or the path to a saved model state dict. + or the path to a saved model state dict. FCN model does not support + pretrained weights. Pretrained ViT weight enums are not supported yet. in_channels: Number of channels in input image num_classes: Number of semantic classes to predict loss: Name of the loss function, currently supports @@ -147,6 +148,9 @@ class and used with 'ce' loss The *class_weights*, *freeze_backbone*, and *freeze_decoder* parameters. + .. versionchanged:: 0.5 + The *weights* parameter supports WeightEnums and checkpoint paths. + """ super().__init__() From dd4f56bc116edef2bbc9204f7755e0a510aa926f Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Wed, 3 May 2023 19:31:38 +0000 Subject: [PATCH 7/8] add the tests back in dummy --- tests/trainers/test_segmentation.py | 3 +++ torchgeo/trainers/segmentation.py | 6 +----- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 24babfa32f7..c04aa99e83a 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -151,6 +151,7 @@ def mocked_weights( def test_weight_file(self, model_kwargs: dict[str, Any], checkpoint: str) -> None: model_kwargs["weights"] = checkpoint + SemanticSegmentationTask(**model_kwargs) def test_weight_enum( self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum @@ -158,6 +159,7 @@ def test_weight_enum( model_kwargs["backbone"] = mocked_weights.meta["model"] model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] model_kwargs["weights"] = mocked_weights + SemanticSegmentationTask(**model_kwargs) def test_weight_str( self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum @@ -165,6 +167,7 @@ def test_weight_str( model_kwargs["backbone"] = mocked_weights.meta["model"] model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] model_kwargs["weights"] = str(mocked_weights) + SemanticSegmentationTask(**model_kwargs) def test_invalid_model(self, model_kwargs: dict[Any, Any]) -> None: model_kwargs["model"] = "invalid_model" diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 9ba773a49a3..575a9b2af6d 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -89,15 +89,11 @@ def config_task(self) -> None: if weights and weights is not True: if isinstance(weights, WeightsEnum): state_dict = weights.get_state_dict(progress=True) - self.model.encoder.load_state_dict(state_dict) elif os.path.exists(weights): _, state_dict = utils.extract_backbone(weights) - self.model.encoder = utils.load_state_dict( - self.model.encoder, state_dict - ) else: state_dict = get_weight(weights).get_state_dict(progress=True) - self.model.encoder.load_state_dict(state_dict) + self.model.encoder.load_state_dict(state_dict) # Freeze backbone if self.hyperparams.get("freeze_backbone", False) and self.hyperparams[ From 3247af7e28a1ec6cac11c04c5d9260c5b3c1a19e Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Wed, 3 May 2023 20:47:53 +0000 Subject: [PATCH 8/8] add tests --- tests/conf/inria.yaml | 2 +- tests/trainers/test_segmentation.py | 18 ++++++++++++++++++ torchgeo/trainers/segmentation.py | 2 +- 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/tests/conf/inria.yaml b/tests/conf/inria.yaml index 7a47124bee5..df4f4043fc4 100644 --- a/tests/conf/inria.yaml +++ b/tests/conf/inria.yaml @@ -3,7 +3,7 @@ module: loss: "ce" model: "unet" backbone: "resnet18" - weights: true + weights: null learning_rate: 1e-3 learning_rate_schedule_patience: 6 in_channels: 3 diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index c04aa99e83a..7fd7badcd34 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -169,6 +169,24 @@ def test_weight_str( model_kwargs["weights"] = str(mocked_weights) SemanticSegmentationTask(**model_kwargs) + @pytest.mark.slow + def test_weight_enum_download( + self, model_kwargs: dict[str, Any], weights: WeightsEnum + ) -> None: + model_kwargs["backbone"] = weights.meta["model"] + model_kwargs["in_channels"] = weights.meta["in_chans"] + model_kwargs["weights"] = weights + SemanticSegmentationTask(**model_kwargs) + + @pytest.mark.slow + def test_weight_str_download( + self, model_kwargs: dict[str, Any], weights: WeightsEnum + ) -> None: + model_kwargs["backbone"] = weights.meta["model"] + model_kwargs["in_channels"] = weights.meta["in_chans"] + model_kwargs["weights"] = str(weights) + SemanticSegmentationTask(**model_kwargs) + def test_invalid_model(self, model_kwargs: dict[Any, Any]) -> None: model_kwargs["model"] = "invalid_model" match = "Model type 'invalid_model' is not valid." diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 575a9b2af6d..883110fb8d1 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -145,7 +145,7 @@ class and used with 'ce' loss and *freeze_decoder* parameters. .. versionchanged:: 0.5 - The *weights* parameter supports WeightEnums and checkpoint paths. + The *weights* parameter now supports WeightEnums and checkpoint paths. """ super().__init__()