From a5db5f80ffac53e13cce6001978d24041206b316 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 11 May 2023 21:52:00 +0000 Subject: [PATCH 1/2] add gassl weights --- torchgeo/models/resnet.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/torchgeo/models/resnet.py b/torchgeo/models/resnet.py index bd785e85f85..8cfa0c8a0ab 100644 --- a/torchgeo/models/resnet.py +++ b/torchgeo/models/resnet.py @@ -170,6 +170,19 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] }, ) + FMOW_RGB_GASSL = Weights( + url="https://huggingface.co/torchgeo/resnet50_fmow_rgb_gassl/resolve/main/resnet50_fmow_rgb_gassl-44b4461b.pth", # noqa: E501 + transforms=_seco_transforms, + meta={ + "dataset": "fMoW Dataset", + "in_chans": 3, + "model": "resnet50", + "publication": "https://arxiv.org/abs/2011.09980", + "repo": "https://github.com/sustainlab-group/geography-aware-ssl", + "ssl_method": "gassl", + }, + ) + def resnet18( weights: Optional[ResNet18_Weights] = None, *args: Any, **kwargs: Any From fe3948daf8a6dc5a210c66513866ae8a1a22faab Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 12 May 2023 20:02:42 +0000 Subject: [PATCH 2/2] add transforms and update docs --- docs/api/resnet_pretrained_weights.csv | 1 + torchgeo/models/resnet.py | 37 +++++++++++++++++--------- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/docs/api/resnet_pretrained_weights.csv b/docs/api/resnet_pretrained_weights.csv index 3c4bf283853..8eef92e4e46 100644 --- a/docs/api/resnet_pretrained_weights.csv +++ b/docs/api/resnet_pretrained_weights.csv @@ -2,6 +2,7 @@ Weight,Channels,Source,Citation,BigEarthNet,EuroSAT,So2Sat,OSCD ResNet18_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,,,, ResNet18_Weights.SENTINEL2_RGB_MOCO, 3,`link `__,`link `__,,,, ResNet18_Weights.SENTINEL2_RGB_SECO, 3,`link `__,`link `__,87.27,93.14,,46.94 +ResNet50_Weights.FMOW_RGB_GASSL, 3,`link `__,`link `__,,,, ResNet50_Weights.SENTINEL1_ALL_MOCO, 2,`link `__,`link `__,,,, ResNet50_Weights.SENTINEL2_ALL_DINO,13,`link `__,`link `__,90.7,99.1,63.6, ResNet50_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,91.8,99.1,60.9, diff --git a/torchgeo/models/resnet.py b/torchgeo/models/resnet.py index 8cfa0c8a0ab..082899d572e 100644 --- a/torchgeo/models/resnet.py +++ b/torchgeo/models/resnet.py @@ -41,6 +41,17 @@ data_keys=["image"], ) +# Normalization only available for RGB dataset, defined here: +# https://github.com/sustainlab-group/geography-aware-ssl/blob/main/moco_fmow/main_moco_geo%2Btp.py#L287 # noqa: E501 +_mean = torch.tensor([0.485, 0.456, 0.406]) +_std = torch.tensor([0.229, 0.224, 0.225]) +_gassl_transforms = AugmentationSequential( + K.Resize(224), + K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), + K.Normalize(mean=_mean, std=_std), + data_keys=["image"], +) + # https://github.com/pytorch/vision/pull/6883 # https://github.com/pytorch/vision/pull/7107 # Can be removed once torchvision>=0.15 is required @@ -105,6 +116,19 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] .. versionadded:: 0.4 """ + FMOW_RGB_GASSL = Weights( + url="https://huggingface.co/torchgeo/resnet50_fmow_rgb_gassl/resolve/main/resnet50_fmow_rgb_gassl-da43d987.pth", # noqa: E501 + transforms=_gassl_transforms, + meta={ + "dataset": "fMoW Dataset", + "in_chans": 3, + "model": "resnet50", + "publication": "https://arxiv.org/abs/2011.09980", + "repo": "https://github.com/sustainlab-group/geography-aware-ssl", + "ssl_method": "gassl", + }, + ) + SENTINEL1_ALL_MOCO = Weights( url="https://huggingface.co/torchgeo/resnet50_sentinel1_all_moco/resolve/main/resnet50_sentinel1_all_moco-906e4356.pth", # noqa: E501 transforms=_zhu_xlab_transforms, @@ -170,19 +194,6 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] }, ) - FMOW_RGB_GASSL = Weights( - url="https://huggingface.co/torchgeo/resnet50_fmow_rgb_gassl/resolve/main/resnet50_fmow_rgb_gassl-44b4461b.pth", # noqa: E501 - transforms=_seco_transforms, - meta={ - "dataset": "fMoW Dataset", - "in_chans": 3, - "model": "resnet50", - "publication": "https://arxiv.org/abs/2011.09980", - "repo": "https://github.com/sustainlab-group/geography-aware-ssl", - "ssl_method": "gassl", - }, - ) - def resnet18( weights: Optional[ResNet18_Weights] = None, *args: Any, **kwargs: Any