Skip to content

Commit

Permalink
Add SimCLR trainer (#1252)
Browse files Browse the repository at this point in the history
* add simclr and tests

* add lightly to reqs

* pyupgrade

* Copy things from prior implementation

* Add SimCLR v2 projection head

* Remove kwargs

* Call __init__ explicitly

* Fix mypy and docs

* Can't test newer setuptools

* Default to output dim of model

* Add memory bank

* Ignore erroneous warning

* Fix configs, test SSL4EO

* Fix a few layer bugs

* mypy fixes

* kernel_size must be an integer

* Fix SeCo in_channels

* Get more coverage

* Bump min lightly

* Default logging

* Test weights

* mypy fix

* Grab max_epochs from the trainer

* max_epochs param removed

* Use num_features

* Remove classification head

* SimCLR uses LARS, with Adam as a backup

* Add warnings

* Grab num features directly from model

* Check if identity

* Match timm model design

* Capture warnings

* Fix tests

* Increase coverage

* Fix method name

* More typos

* Escape regex

* Newer setuptools now supported

* New batch norm for every layer

* Rename forward arg

* Clarify usage of weights parameter

Co-authored-by: Caleb Robinson <[email protected]>

* Fix flake8

* Check it

* Use hydra

* Track average L2 normed stdev over features

* SimCLR decays lr to 0

* Add lr warmup

* Fix version access

* Fix LinearLR

* isinstance supports tuples

* Comment capitalization

* Require lightly 1.4.3+

* Require lightly 1.4.3+

* Bump lightly version

* Add RandomGrayscale

* Flake8 fixes

* Placate pydocstyle

* Clarify docs

* Pass correct weights

---------

Co-authored-by: Adam J. Stewart <[email protected]>
Co-authored-by: Caleb Robinson <[email protected]>
  • Loading branch information
3 people authored May 11, 2023
1 parent 3cc1427 commit ef7a9ad
Show file tree
Hide file tree
Showing 16 changed files with 535 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"),
("py:class", "timm.models.resnet.ResNet"),
("py:class", "timm.models.vision_transformer.VisionTransformer"),
("py:class", "torch.optim.lr_scheduler.LRScheduler"),
("py:class", "torchvision.models._api.WeightsEnum"),
("py:class", "torchvision.models.resnet.ResNet"),
]
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies:
- isort[colors]>=5.8
- kornia>=0.6.5
- laspy>=2
- lightly>=1.4.4
- lightning>=1.8
- mypy>=0.900
- nbmake>=1.3.3
Expand Down
1 change: 1 addition & 0 deletions requirements/min-reqs.old
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ setuptools==42.0.0
einops==0.3.0
fiona==1.8.19
kornia==0.6.5
lightly==1.4.4
lightning==1.8.0
matplotlib==3.3.3
numpy==1.19.3
Expand Down
1 change: 1 addition & 0 deletions requirements/required.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ einops==0.6.1
fiona==1.9.3
kornia==0.6.12
lightning==2.0.2
lightly==1.4.4
matplotlib==3.7.1
numpy==1.24.3
pillow==9.5.0
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ install_requires =
fiona>=1.8.19,<2
# kornia 0.6.5+ required due to change in kornia.augmentation API
kornia>=0.6.5,<0.7
# lightly 1.4.4+ required for MoCo v3 support
lightly>=1.4.4
# lightning 1.8+ is first release
lightning>=1.8,<3
# matplotlib 3.3.3+ required for Python 3.9 wheels
Expand Down
23 changes: 23 additions & 0 deletions tests/conf/chesapeake_cvpr_prior_simclr.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
module:
_target_: torchgeo.trainers.SimCLRTask
model: "resnet18"
in_channels: 4
version: 1
layers: 2
memory_bank_size: 0

datamodule:
_target_: torchgeo.datamodules.ChesapeakeCVPRDataModule
root: "tests/data/chesapeake/cvpr"
download: false
train_splits:
- "de-test"
val_splits:
- "de-test"
test_splits:
- "de-test"
batch_size: 2
patch_size: 64
num_workers: 0
class_set: 5
use_prior_labels: True
17 changes: 17 additions & 0 deletions tests/conf/seco_simclr_1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module:
_target_: torchgeo.trainers.SimCLRTask
model: "resnet18"
in_channels: 3
version: 1
layers: 2
hidden_dim: 8
output_dim: 8
weight_decay: 1e-6
memory_bank_size: 0

datamodule:
_target_: torchgeo.datamodules.SeasonalContrastS2DataModule
root: "tests/data/seco"
seasons: 1
batch_size: 2
num_workers: 0
17 changes: 17 additions & 0 deletions tests/conf/seco_simclr_2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module:
_target_: torchgeo.trainers.SimCLRTask
model: "resnet18"
in_channels: 3
version: 2
layers: 4
hidden_dim: 8
output_dim: 8
weight_decay: 1e-4
memory_bank_size: 10

datamodule:
_target_: torchgeo.datamodules.SeasonalContrastS2DataModule
root: "tests/data/seco"
seasons: 2
batch_size: 2
num_workers: 0
17 changes: 17 additions & 0 deletions tests/conf/ssl4eo_s12_simclr_1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module:
_target_: torchgeo.trainers.SimCLRTask
model: "resnet18"
in_channels: 13
version: 1
layers: 2
hidden_dim: 8
output_dim: 8
weight_decay: 1e-6
memory_bank_size: 0

datamodule:
_target_: torchgeo.datamodules.SSL4EOS12DataModule
root: "tests/data/ssl4eo/s12"
seasons: 1
batch_size: 2
num_workers: 0
17 changes: 17 additions & 0 deletions tests/conf/ssl4eo_s12_simclr_2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module:
_target_: torchgeo.trainers.SimCLRTask
model: "resnet18"
in_channels: 13
version: 2
layers: 3
hidden_dim: 8
output_dim: 8
weight_decay: 1e-4
memory_bank_size: 10

datamodule:
_target_: torchgeo.datamodules.SSL4EOS12DataModule
root: "tests/data/ssl4eo/s12"
seasons: 2
batch_size: 2
num_workers: 0
3 changes: 2 additions & 1 deletion tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def __init__(self, in_chans: int = 3, num_classes: int = 10, **kwargs: Any) -> N
super().__init__()
self.conv1 = nn.Conv2d(in_channels=in_chans, out_channels=1, kernel_size=1)
self.pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(1, num_classes)
self.fc = nn.Linear(1, num_classes) if num_classes else nn.Identity()
self.num_features = 1

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv1(x)
Expand Down
154 changes: 154 additions & 0 deletions tests/trainers/test_simclr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from pathlib import Path
from typing import Any

import pytest
import timm
import torch
import torchvision
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from omegaconf import OmegaConf
from torch.nn import Module
from torchvision.models._api import WeightsEnum

from torchgeo.datasets import SSL4EOS12, SeasonalContrastS2
from torchgeo.models import get_model_weights, list_models
from torchgeo.trainers import SimCLRTask

from .test_classification import ClassificationTestModel


def create_model(*args: Any, **kwargs: Any) -> Module:
return ClassificationTestModel(**kwargs)


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestSimCLRTask:
@pytest.mark.parametrize(
"name",
[
"chesapeake_cvpr_prior_simclr",
"seco_simclr_1",
"seco_simclr_2",
"ssl4eo_s12_simclr_1",
"ssl4eo_s12_simclr_2",
],
)
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))

if name.startswith("seco"):
monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2)

if name.startswith("ssl4eo_s12"):
monkeypatch.setattr(SSL4EOS12, "__len__", lambda self: 2)

# Instantiate datamodule
datamodule = instantiate(conf.datamodule)

# Instantiate model
monkeypatch.setattr(timm, "create_model", create_model)
model = instantiate(conf.module)

# Instantiate trainer
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.fit(model=model, datamodule=datamodule)

def test_version_warnings(self) -> None:
with pytest.warns(UserWarning, match="SimCLR v1 only uses 2 layers"):
SimCLRTask(version=1, layers=3)
with pytest.warns(UserWarning, match="SimCLR v1 does not use a memory bank"):
SimCLRTask(version=1, memory_bank_size=10)
with pytest.warns(UserWarning, match=r"SimCLR v2 uses 3\+ layers"):
SimCLRTask(version=2, layers=2)
with pytest.warns(UserWarning, match="SimCLR v2 uses a memory bank"):
SimCLRTask(version=2, memory_bank_size=0)

@pytest.fixture(
params=[
weights for model in list_models() for weights in get_model_weights(model)
]
)
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
) -> WeightsEnum:
path = tmp_path / f"{weights}.pth"
model = timm.create_model(
weights.meta["model"], in_chans=weights.meta["in_chans"]
)
torch.save(model.state_dict(), path)
try:
monkeypatch.setattr(weights.value, "url", str(path))
except AttributeError:
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
return weights

def test_weight_file(self, checkpoint: str) -> None:
model_kwargs: dict[str, Any] = {"model": "resnet18", "weights": checkpoint}
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
SimCLRTask(**model_kwargs)

def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
model_kwargs: dict[str, Any] = {
"model": mocked_weights.meta["model"],
"weights": mocked_weights,
"in_channels": mocked_weights.meta["in_chans"],
}
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
SimCLRTask(**model_kwargs)

def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
model_kwargs: dict[str, Any] = {
"model": mocked_weights.meta["model"],
"weights": str(mocked_weights),
"in_channels": mocked_weights.meta["in_chans"],
}
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
SimCLRTask(**model_kwargs)

@pytest.mark.slow
def test_weight_enum_download(self, weights: WeightsEnum) -> None:
model_kwargs: dict[str, Any] = {
"model": weights.meta["model"],
"weights": weights,
"in_channels": weights.meta["in_chans"],
}
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
SimCLRTask(**model_kwargs)

@pytest.mark.slow
def test_weight_str_download(self, weights: WeightsEnum) -> None:
model_kwargs: dict[str, Any] = {
"model": weights.meta["model"],
"weights": str(weights),
"in_channels": weights.meta["in_chans"],
}
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
SimCLRTask(**model_kwargs)
2 changes: 2 additions & 0 deletions torchgeo/trainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .detection import ObjectDetectionTask
from .regression import PixelwiseRegressionTask, RegressionTask
from .segmentation import SemanticSegmentationTask
from .simclr import SimCLRTask

__all__ = (
"BYOLTask",
Expand All @@ -17,4 +18,5 @@
"PixelwiseRegressionTask",
"RegressionTask",
"SemanticSegmentationTask",
"SimCLRTask",
)
Loading

0 comments on commit ef7a9ad

Please sign in to comment.