Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SeCo/BYOL: add datamodule, RandomSeasonContrast #1168

Merged
merged 14 commits into from
Mar 17, 2023
Previous commit
Next commit
Increase coverage
  • Loading branch information
adamjstewart committed Mar 10, 2023
commit 387c0591a77368a0cec85d6522e9f68fe522c6b4
113 changes: 94 additions & 19 deletions tests/datamodules/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@

from typing import Any, Dict

import matplotlib.pyplot as plt
import pytest
import torch
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from torch import Tensor
from pytorch_lightning import Trainer

from torchgeo.datamodules import (
GeoDataModule,
Expand All @@ -23,15 +27,19 @@ def __init__(self, split: str = "train", download: bool = False) -> None:
self.res = 1

def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
return {"image": torch.arange(3 * 2 * 2).view(3, 2, 2)}
image = torch.arange(3 * 2 * 2).view(3, 2, 2)
return {"image": image, "crs": CRS.from_epsg(4326), "bbox": query}

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
    """Return a newly created pyplot figure.

    Positional and keyword arguments are accepted but not used.
    """
    figure = plt.figure()
    return figure


class CustomGeoDataModule(GeoDataModule):
    """GeoDataModule wired to CustomGeoDataset with a fixed test configuration."""

    def __init__(self) -> None:
        """Forward the hard-coded test arguments to GeoDataModule."""
        # The positional values are passed through unchanged; their meaning is
        # defined by GeoDataModule.__init__, which is not visible here.
        super().__init__(
            CustomGeoDataset,
            1,
            1,
            1,
            0,
            download=True,
        )


class SamplerGeoDatModule(CustomGeoDataModule):
class SamplerGeoDataModule(CustomGeoDataModule):
def setup(self, stage: str) -> None:
self.dataset = CustomGeoDataset()
self.train_sampler = RandomGeoSampler(self.dataset, 1, 1)
Expand All @@ -40,7 +48,7 @@ def setup(self, stage: str) -> None:
self.predict_sampler = RandomGeoSampler(self.dataset, 1, 1)


class BatchSamplerGeoDatModule(CustomGeoDataModule):
class BatchSamplerGeoDataModule(CustomGeoDataModule):
def setup(self, stage: str) -> None:
self.dataset = CustomGeoDataset()
self.train_batch_sampler = RandomBatchGeoSampler(self.dataset, 1, 1, 1)
Expand All @@ -59,34 +67,66 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]:
def __len__(self) -> int:
    """Report a fixed dataset length of one."""
    length = 1
    return length

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
    """Create an empty pyplot figure and hand it back.

    Every argument is ignored; only the return value matters to callers.
    """
    new_figure = plt.figure()
    return new_figure


class CustomNonGeoDataModule(NonGeoDataModule):
    """NonGeoDataModule wired to CustomNonGeoDataset, adding a predict dataset."""

    def __init__(self) -> None:
        """Forward the hard-coded test arguments to NonGeoDataModule."""
        # Positional values are passed through as-is; see
        # NonGeoDataModule.__init__ for what each one configures.
        super().__init__(
            CustomNonGeoDataset,
            1,
            0,
            download=True,
        )

    def setup(self, stage: str) -> None:
        """Run the parent setup, then create ``predict_dataset`` for "predict".

        Args:
            stage: Name of the stage being set up.
        """
        super().setup(stage)

        # Only the predict stage gets an extra dataset from this subclass.
        if stage != "predict":
            return
        self.predict_dataset = CustomNonGeoDataset()


class TestGeoDataModule:
@pytest.fixture(params=[SamplerGeoDataModule, BatchSamplerGeoDataModule])
def datamodule(self, request: SubRequest) -> CustomGeoDataModule:
    """Build each parametrized datamodule class and attach a one-epoch Trainer."""
    module_cls = request.param
    module = module_cls()
    # The tests below toggle flags on ``trainer``, so one must be attached.
    module.trainer = Trainer(max_epochs=1)
    return module

@pytest.mark.parametrize("stage", ["fit", "validate", "test"])
def test_setup(self, stage: str) -> None:
    """Check that prepare_data() followed by setup(stage) runs for each stage."""
    module = CustomGeoDataModule()
    module.prepare_data()
    module.setup(stage)

def test_sampler(self) -> None:
dm = SamplerGeoDatModule()
dm.setup("fit")
dm.train_dataloader()
dm.val_dataloader()
dm.test_dataloader()
dm.predict_dataloader()

def test_batch_sampler(self) -> None:
dm = BatchSamplerGeoDatModule()
dm.setup("fit")
dm.train_dataloader()
dm.val_dataloader()
dm.test_dataloader()
dm.predict_dataloader()
def test_train(self, datamodule: CustomGeoDataModule) -> None:
    """Pull one training batch and run it through both batch-transfer hooks."""
    datamodule.setup("fit")
    # NOTE(review): presumably the hooks branch on this trainer flag - confirm
    # against GeoDataModule.
    datamodule.trainer.training = True  # type: ignore[union-attr]
    loader = datamodule.train_dataloader()
    first_batch = next(iter(loader))
    cpu = torch.device("cpu")
    moved = datamodule.transfer_batch_to_device(first_batch, cpu, 1)
    datamodule.on_after_batch_transfer(moved, 0)

def test_val(self, datamodule: CustomGeoDataModule) -> None:
    """Pull one validation batch and run it through both batch-transfer hooks."""
    datamodule.setup("validate")
    # NOTE(review): presumably the hooks branch on this trainer flag - confirm
    # against GeoDataModule.
    datamodule.trainer.validating = True  # type: ignore[union-attr]
    loader = datamodule.val_dataloader()
    first_batch = next(iter(loader))
    cpu = torch.device("cpu")
    moved = datamodule.transfer_batch_to_device(first_batch, cpu, 1)
    datamodule.on_after_batch_transfer(moved, 0)

def test_test(self, datamodule: CustomGeoDataModule) -> None:
    """Pull one test batch and run it through both batch-transfer hooks."""
    datamodule.setup("test")
    # NOTE(review): presumably the hooks branch on this trainer flag - confirm
    # against GeoDataModule.
    datamodule.trainer.testing = True  # type: ignore[union-attr]
    loader = datamodule.test_dataloader()
    first_batch = next(iter(loader))
    cpu = torch.device("cpu")
    moved = datamodule.transfer_batch_to_device(first_batch, cpu, 1)
    datamodule.on_after_batch_transfer(moved, 0)

def test_predict(self, datamodule: CustomGeoDataModule) -> None:
    """Pull one predict batch and run it through both batch-transfer hooks."""
    datamodule.setup("predict")
    # NOTE(review): presumably the hooks branch on this trainer flag - confirm
    # against GeoDataModule.
    datamodule.trainer.predicting = True  # type: ignore[union-attr]
    loader = datamodule.predict_dataloader()
    first_batch = next(iter(loader))
    cpu = torch.device("cpu")
    moved = datamodule.transfer_batch_to_device(first_batch, cpu, 1)
    datamodule.on_after_batch_transfer(moved, 0)

def test_plot(self, datamodule: CustomGeoDataModule) -> None:
    """Call plot() on a datamodule set up for validation."""
    datamodule.setup("validate")
    datamodule.plot()
    # Close the current pyplot figure so figures do not pile up across tests.
    plt.close()

def test_no_datasets(self) -> None:
dm = CustomGeoDataModule()
Expand All @@ -102,12 +142,47 @@ def test_no_datasets(self) -> None:


class TestNonGeoDataModule:
@pytest.mark.parametrize("stage", ["fit", "validate", "test"])
@pytest.fixture
def datamodule(self) -> CustomNonGeoDataModule:
    """Provide a CustomNonGeoDataModule with a one-epoch Trainer attached."""
    module = CustomNonGeoDataModule()
    # The tests below toggle flags on ``trainer``, so one must be attached.
    module.trainer = Trainer(max_epochs=1)
    return module

@pytest.mark.parametrize("stage", ["fit", "validate", "test", "predict"])
def test_setup(self, stage: str) -> None:
    """Check that prepare_data() followed by setup(stage) runs for each stage."""
    module = CustomNonGeoDataModule()
    module.prepare_data()
    module.setup(stage)

def test_train(self, datamodule: CustomNonGeoDataModule) -> None:
    """Pull one training batch and pass it to on_after_batch_transfer."""
    datamodule.setup("fit")
    # NOTE(review): presumably the hook branches on this trainer flag - confirm
    # against NonGeoDataModule.
    datamodule.trainer.training = True  # type: ignore[union-attr]
    loader = datamodule.train_dataloader()
    first_batch = next(iter(loader))
    datamodule.on_after_batch_transfer(first_batch, 0)

def test_val(self, datamodule: CustomNonGeoDataModule) -> None:
    """Pull one validation batch and pass it to on_after_batch_transfer."""
    datamodule.setup("validate")
    # NOTE(review): presumably the hook branches on this trainer flag - confirm
    # against NonGeoDataModule.
    datamodule.trainer.validating = True  # type: ignore[union-attr]
    loader = datamodule.val_dataloader()
    first_batch = next(iter(loader))
    datamodule.on_after_batch_transfer(first_batch, 0)

def test_test(self, datamodule: CustomNonGeoDataModule) -> None:
    """Pull one test batch and pass it to on_after_batch_transfer."""
    datamodule.setup("test")
    # NOTE(review): presumably the hook branches on this trainer flag - confirm
    # against NonGeoDataModule.
    datamodule.trainer.testing = True  # type: ignore[union-attr]
    loader = datamodule.test_dataloader()
    first_batch = next(iter(loader))
    datamodule.on_after_batch_transfer(first_batch, 0)

def test_predict(self, datamodule: CustomNonGeoDataModule) -> None:
    """Pull one predict batch and pass it to on_after_batch_transfer."""
    datamodule.setup("predict")
    # NOTE(review): presumably the hook branches on this trainer flag - confirm
    # against NonGeoDataModule.
    datamodule.trainer.predicting = True  # type: ignore[union-attr]
    loader = datamodule.predict_dataloader()
    first_batch = next(iter(loader))
    datamodule.on_after_batch_transfer(first_batch, 0)

def test_plot(self, datamodule: CustomNonGeoDataModule) -> None:
    """Call plot() on a datamodule set up for validation."""
    datamodule.setup("validate")
    datamodule.plot()
    # Close the current pyplot figure so figures do not pile up across tests.
    plt.close()

def test_no_datasets(self) -> None:
dm = CustomNonGeoDataModule()
msg = "CustomNonGeoDataModule.setup does not define a '{}_dataset'"
Expand Down