From 7ad1f0c102e824e3a352352177968ddac4dcff03 Mon Sep 17 00:00:00 2001
From: Nils Lehmann <35272119+nilsleh@users.noreply.github.com>
Date: Mon, 20 Feb 2023 21:40:51 +0100
Subject: [PATCH] Add seco transforms and zhu normalization to pretrained weights (#1119)

* add seco transforms and zhu normalization

* adapt links

* add additional comment zhu lab

* left from merge
---
 torchgeo/models/resnet.py | 55 ++++++++++++++++++++++++++++++++++++---
 torchgeo/models/vit.py    |  8 +++++-
 2 files changed, 58 insertions(+), 5 deletions(-)

diff --git a/torchgeo/models/resnet.py b/torchgeo/models/resnet.py
index 6240b40405b..0dfd5d848c2 100644
--- a/torchgeo/models/resnet.py
+++ b/torchgeo/models/resnet.py
@@ -7,7 +7,7 @@
 
 import kornia.augmentation as K
 import timm
-import torch.nn as nn
+import torch
 from timm.models import ResNet
 from torchvision.models._api import Weights, WeightsEnum
 
@@ -15,8 +15,55 @@
 
 __all__ = ["ResNet50_Weights", "ResNet18_Weights"]
 
+
+# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167  # noqa: E501
+# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97  # noqa: E501
+# Normalization either by 10K or channel-wise with band statistics
 _zhu_xlab_transforms = AugmentationSequential(
-    K.Resize(256), K.CenterCrop(224), data_keys=["image"]
+    K.Resize(256),
+    K.CenterCrop(224),
+    K.Normalize(mean=0, std=10000),
+    data_keys=["image"],
+)
+
+# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/bigearthnet_dataset.py#L13  # noqa: E501
+_seco_transforms = AugmentationSequential(
+    K.Resize(128),
+    K.Normalize(
+        mean=torch.Tensor(
+            [
+                340.76769064,
+                429.9430203,
+                614.21682446,
+                590.23569706,
+                950.68368468,
+                1792.46290469,
+                2075.46795189,
+                2218.94553375,
+                2266.46036911,
+                2246.0605464,
+                1594.42694882,
+                1009.32729131,
+            ]
+        ),
+        std=torch.Tensor(
+            [
+                554.81258967,
+                572.41639287,
+                582.87945694,
+                675.88746967,
+                729.89827633,
+                1096.01480586,
+                1273.45393088,
+                1365.45589904,
+                1356.13789355,
+                1302.3292881,
+                1079.19066363,
+                818.86747235,
+            ]
+        ),
+    ),
+    data_keys=["image"],
 )
 
 # https://github.com/pytorch/vision/pull/6883
@@ -62,7 +109,7 @@ class ResNet18_Weights(WeightsEnum):  # type: ignore[misc]
 
     SENTINEL2_RGB_SECO = Weights(
         url="https://huggingface.co/torchgeo/resnet18_sentinel2_rgb_seco/resolve/main/resnet18_sentinel2_rgb_seco-9976a9cb.pth",  # noqa: E501
-        transforms=nn.Identity(),
+        transforms=_seco_transforms,
         meta={
             "dataset": "SeCo Dataset",
             "in_chans": 3,
@@ -137,7 +184,7 @@ class ResNet50_Weights(WeightsEnum):  # type: ignore[misc]
 
     SENTINEL2_RGB_SECO = Weights(
         url="https://huggingface.co/torchgeo/resnet50_sentinel2_rgb_seco/resolve/main/resnet50_sentinel2_rgb_seco-584035db.pth",  # noqa: E501
-        transforms=nn.Identity(),
+        transforms=_seco_transforms,
         meta={
             "dataset": "SeCo Dataset",
             "in_chans": 3,
diff --git a/torchgeo/models/vit.py b/torchgeo/models/vit.py
index 1f42c40405e..52fca28181e 100644
--- a/torchgeo/models/vit.py
+++ b/torchgeo/models/vit.py
@@ -14,8 +14,14 @@
 
 __all__ = ["ViTSmall16_Weights"]
 
+# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167  # noqa: E501
+# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97  # noqa: E501
+# Normalization either by 10K or channel-wise with band statistics
 _zhu_xlab_transforms = AugmentationSequential(
-    K.Resize(256), K.CenterCrop(224), data_keys=["image"]
+    K.Resize(256),
+    K.CenterCrop(224),
+    K.Normalize(mean=0, std=10000),
+    data_keys=["image"],
 )
 
 # https://github.com/pytorch/vision/pull/6883