-
Notifications
You must be signed in to change notification settings - Fork 408
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add simclr and tests * add lightly to reqs * pyupgrade * Copy things from prior implementation * Add SimCLR v2 projection head * Remove kwargs * Call __init__ explicitly * Fix mypy and docs * Can't test newer setuptools * Default to output dim of model * Add memory bank * Ignore erroneous warning * Fix configs, test SSL4EO * Fix a few layer bugs * mypy fixes * kernel_size must be an integer * Fix SeCo in_channels * Get more coverage * Bump min lightly * Default logging * Test weights * mypy fix * Grab max_epochs from the trainer * max_epochs param removed * Use num_features * Remove classification head * SimCLR uses LARS, with Adam as a backup * Add warnings * Grab num features directly from model * Check if identity * Match timm model design * Capture warnings * Fix tests * Increase coverage * Fix method name * More typos * Escape regex * Newer setuptools now supported * New batch norm for every layer * Rename forward arg * Clarify usage of weights parameter Co-authored-by: Caleb Robinson <[email protected]> * Fix flake8 * Check it * Use hydra * Track average L2 normed stdev over features * SimCLR decays lr to 0 * Add lr warmup * Fix version access * Fix LinearLR * isinstance supports tuples * Comment capitalization * Require lightly 1.4.3+ * Require lightly 1.4.3+ * Bump lightly version * Add RandomGrayscale * Flake8 fixes * Placate pydocstyle * Clarify docs * Pass correct weights --------- Co-authored-by: Adam J. Stewart <[email protected]> Co-authored-by: Caleb Robinson <[email protected]>
- Loading branch information
1 parent
3cc1427
commit ef7a9ad
Showing
16 changed files
with
535 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# SimCLRTask paired with the Chesapeake CVPR datamodule.
# Each `_target_` is the dotted path of the class that the section configures.
# NOTE(review): presumably stored as tests/conf/chesapeake_cvpr_prior_simclr.yaml
# (the file name is not visible here) -- confirm.
module:
  _target_: torchgeo.trainers.SimCLRTask
  model: "resnet18"
  in_channels: 4
  # Version 1 combined with 2 layers and an empty memory bank.
  version: 1
  layers: 2
  memory_bank_size: 0

datamodule:
  _target_: torchgeo.datamodules.ChesapeakeCVPRDataModule
  root: "tests/data/chesapeake/cvpr"
  download: false
  # The same single split is used for training, validation and testing.
  train_splits:
    - "de-test"
  val_splits:
    - "de-test"
  test_splits:
    - "de-test"
  batch_size: 2
  patch_size: 64
  num_workers: 0
  class_set: 5
  use_prior_labels: True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# SimCLRTask (version 1) paired with the Seasonal Contrast S2 datamodule.
# Each `_target_` is the dotted path of the class that the section configures.
# NOTE(review): presumably stored as tests/conf/seco_simclr_1.yaml
# (the file name is not visible here) -- confirm.
module:
  _target_: torchgeo.trainers.SimCLRTask
  model: "resnet18"
  in_channels: 3
  # Version 1 combined with 2 layers and an empty memory bank.
  version: 1
  layers: 2
  hidden_dim: 8
  output_dim: 8
  weight_decay: 1e-6
  memory_bank_size: 0

datamodule:
  _target_: torchgeo.datamodules.SeasonalContrastS2DataModule
  root: "tests/data/seco"
  seasons: 1
  batch_size: 2
  num_workers: 0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# SimCLRTask (version 2) paired with the Seasonal Contrast S2 datamodule.
# Each `_target_` is the dotted path of the class that the section configures.
# NOTE(review): presumably stored as tests/conf/seco_simclr_2.yaml
# (the file name is not visible here) -- confirm.
module:
  _target_: torchgeo.trainers.SimCLRTask
  model: "resnet18"
  in_channels: 3
  # Version 2 combined with 4 layers and a non-empty memory bank.
  version: 2
  layers: 4
  hidden_dim: 8
  output_dim: 8
  weight_decay: 1e-4
  memory_bank_size: 10

datamodule:
  _target_: torchgeo.datamodules.SeasonalContrastS2DataModule
  root: "tests/data/seco"
  seasons: 2
  batch_size: 2
  num_workers: 0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# SimCLRTask (version 1) paired with the SSL4EO-S12 datamodule.
# Each `_target_` is the dotted path of the class that the section configures.
# NOTE(review): presumably stored as tests/conf/ssl4eo_s12_simclr_1.yaml
# (the file name is not visible here) -- confirm.
module:
  _target_: torchgeo.trainers.SimCLRTask
  model: "resnet18"
  in_channels: 13
  # Version 1 combined with 2 layers and an empty memory bank.
  version: 1
  layers: 2
  hidden_dim: 8
  output_dim: 8
  weight_decay: 1e-6
  memory_bank_size: 0

datamodule:
  _target_: torchgeo.datamodules.SSL4EOS12DataModule
  root: "tests/data/ssl4eo/s12"
  seasons: 1
  batch_size: 2
  num_workers: 0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# SimCLRTask (version 2) paired with the SSL4EO-S12 datamodule.
# Each `_target_` is the dotted path of the class that the section configures.
# NOTE(review): presumably stored as tests/conf/ssl4eo_s12_simclr_2.yaml
# (the file name is not visible here) -- confirm.
module:
  _target_: torchgeo.trainers.SimCLRTask
  model: "resnet18"
  in_channels: 13
  # Version 2 combined with 3 layers and a non-empty memory bank.
  version: 2
  layers: 3
  hidden_dim: 8
  output_dim: 8
  weight_decay: 1e-4
  memory_bank_size: 10

datamodule:
  _target_: torchgeo.datamodules.SSL4EOS12DataModule
  root: "tests/data/ssl4eo/s12"
  seasons: 2
  batch_size: 2
  num_workers: 0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import os | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
import pytest | ||
import timm | ||
import torch | ||
import torchvision | ||
from _pytest.fixtures import SubRequest | ||
from _pytest.monkeypatch import MonkeyPatch | ||
from hydra.utils import instantiate | ||
from lightning.pytorch import Trainer | ||
from omegaconf import OmegaConf | ||
from torch.nn import Module | ||
from torchvision.models._api import WeightsEnum | ||
|
||
from torchgeo.datasets import SSL4EOS12, SeasonalContrastS2 | ||
from torchgeo.models import get_model_weights, list_models | ||
from torchgeo.trainers import SimCLRTask | ||
|
||
from .test_classification import ClassificationTestModel | ||
|
||
|
||
def create_model(*args: Any, **kwargs: Any) -> Module:
    """Build a ``ClassificationTestModel`` from the keyword arguments.

    Positional arguments are accepted but not forwarded. The trainer test
    installs this function in place of ``timm.create_model``.
    """
    test_model: Module = ClassificationTestModel(**kwargs)
    return test_model
|
||
|
||
def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
    """Read a state dict from ``url`` with ``torch.load``.

    Any further positional or keyword arguments are accepted and ignored.
    The ``mocked_weights`` fixture installs this function in place of
    ``torchvision.models._api.load_state_dict_from_url``.
    """
    loaded: dict[str, Any] = torch.load(url)
    return loaded
|
||
|
||
class TestSimCLRTask:
    """Tests for ``SimCLRTask``: a training run, warnings and weight loading."""

    @pytest.mark.parametrize(
        "name",
        [
            "chesapeake_cvpr_prior_simclr",
            "seco_simclr_1",
            "seco_simclr_2",
            "ssl4eo_s12_simclr_1",
            "ssl4eo_s12_simclr_2",
        ],
    )
    def test_trainer(
        self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
    ) -> None:
        """Fit the task described by ``tests/conf/<name>.yaml`` on the CPU."""
        config_path = os.path.join("tests", "conf", f"{name}.yaml")
        config = OmegaConf.load(config_path)

        # Make the matching dataset class report a length of 2.
        length_overrides = (("seco", SeasonalContrastS2), ("ssl4eo_s12", SSL4EOS12))
        for prefix, dataset_cls in length_overrides:
            if name.startswith(prefix):
                monkeypatch.setattr(dataset_cls, "__len__", lambda self: 2)

        # Datamodule first, exactly as configured.
        datamodule = instantiate(config.datamodule)

        # The task is built with timm.create_model swapped for the test stub.
        monkeypatch.setattr(timm, "create_model", create_model)
        task = instantiate(config.module)

        # Single-epoch CPU trainer that logs on every step.
        trainer_kwargs: dict[str, Any] = {
            "accelerator": "cpu",
            "fast_dev_run": fast_dev_run,
            "log_every_n_steps": 1,
            "max_epochs": 1,
        }
        runner = Trainer(**trainer_kwargs)
        runner.fit(model=task, datamodule=datamodule)

    def test_version_warnings(self) -> None:
        """Check the ``UserWarning`` raised for each version/option mismatch."""
        cases: list[tuple[str, dict[str, Any]]] = [
            ("SimCLR v1 only uses 2 layers", {"version": 1, "layers": 3}),
            (
                "SimCLR v1 does not use a memory bank",
                {"version": 1, "memory_bank_size": 10},
            ),
            (r"SimCLR v2 uses 3\+ layers", {"version": 2, "layers": 2}),
            ("SimCLR v2 uses a memory bank", {"version": 2, "memory_bank_size": 0}),
        ]
        for pattern, task_kwargs in cases:
            with pytest.warns(UserWarning, match=pattern):
                SimCLRTask(**task_kwargs)

    @pytest.fixture(
        params=[
            member
            for model_name in list_models()
            for member in get_model_weights(model_name)
        ]
    )
    def weights(self, request: SubRequest) -> WeightsEnum:
        """Yield each weights member of every model from ``list_models()``."""
        member: WeightsEnum = request.param
        return member

    @pytest.fixture
    def mocked_weights(
        self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
    ) -> WeightsEnum:
        """Point ``weights`` at a state dict saved under ``tmp_path``.

        A model is created with timm from the weights metadata and its state
        dict is saved to a temporary ``.pth`` file. The ``url`` of the weights
        is then replaced by that path, and torchvision's
        ``load_state_dict_from_url`` is replaced by the local ``load`` helper.
        """
        checkpoint_path = tmp_path / f"{weights}.pth"
        backbone = timm.create_model(
            weights.meta["model"], in_chans=weights.meta["in_chans"]
        )
        torch.save(backbone.state_dict(), checkpoint_path)

        # Patch ``url`` on the enum value; fall back to the member itself
        # when that raises AttributeError.
        try:
            monkeypatch.setattr(weights.value, "url", str(checkpoint_path))
        except AttributeError:
            monkeypatch.setattr(weights, "url", str(checkpoint_path))

        monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
        return weights

    def test_weight_file(self, checkpoint: str) -> None:
        """Build the task with ``weights`` given as the ``checkpoint`` string."""
        task_kwargs: dict[str, Any] = {"model": "resnet18", "weights": checkpoint}
        pattern = "num classes .* != num classes in pretrained model"
        with pytest.warns(UserWarning, match=pattern):
            SimCLRTask(**task_kwargs)

    def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
        """Build the task from a mocked weights enum member."""
        pattern = "num classes .* != num classes in pretrained model"
        task_kwargs: dict[str, Any] = {
            "model": mocked_weights.meta["model"],
            "weights": mocked_weights,
            "in_channels": mocked_weights.meta["in_chans"],
        }
        with pytest.warns(UserWarning, match=pattern):
            SimCLRTask(**task_kwargs)

    def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
        """Build the task from the string form of a mocked weights member."""
        pattern = "num classes .* != num classes in pretrained model"
        task_kwargs: dict[str, Any] = {
            "model": mocked_weights.meta["model"],
            "weights": str(mocked_weights),
            "in_channels": mocked_weights.meta["in_chans"],
        }
        with pytest.warns(UserWarning, match=pattern):
            SimCLRTask(**task_kwargs)

    @pytest.mark.slow
    def test_weight_enum_download(self, weights: WeightsEnum) -> None:
        """Build the task from an unmocked weights enum member (marked slow)."""
        pattern = "num classes .* != num classes in pretrained model"
        task_kwargs: dict[str, Any] = {
            "model": weights.meta["model"],
            "weights": weights,
            "in_channels": weights.meta["in_chans"],
        }
        with pytest.warns(UserWarning, match=pattern):
            SimCLRTask(**task_kwargs)

    @pytest.mark.slow
    def test_weight_str_download(self, weights: WeightsEnum) -> None:
        """Build the task from an unmocked weights string (marked slow)."""
        pattern = "num classes .* != num classes in pretrained model"
        task_kwargs: dict[str, Any] = {
            "model": weights.meta["model"],
            "weights": str(weights),
            "in_channels": weights.meta["in_chans"],
        }
        with pytest.warns(UserWarning, match=pattern):
            SimCLRTask(**task_kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.