USAVars: implementing DataModule #441

Merged · 25 commits · Jun 27, 2022
adapt module to dataset refactor
iejMac authored and adamjstewart committed Jun 27, 2022
commit 1a53b08dcc1681f84d2dd4a3fdce50983caf7c0e
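
In outline: the USAVars dataset was refactored to own its train/val/test splits, so this commit drops the datamodule's random splitting (dataset_split, val_split_pct, test_split_pct, fixed_shuffle) and builds one dataset per split instead. A minimal usage sketch of the simplified datamodule, assuming only the calls visible in this diff (the root path is the test fixture's; batch size and worker count are illustrative):

    from torchgeo.datamodules import USAVarsDataModule

    dm = USAVarsDataModule("tests/data/usavars", batch_size=1, num_workers=0)
    dm.prepare_data()  # instantiates USAVars once to check the data on disk
    dm.setup()         # builds the train/val/test USAVars datasets
    batch = next(iter(dm.train_dataloader()))  # dict with "image" and "labels"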
tests/datamodules/test_usavars.py (10 additions, 40 deletions)
@@ -10,58 +10,28 @@
 
 
 class TestUSAVarsDataModule:
-    @pytest.fixture(
-        scope="class",
-        params=zip(
-            [["elevation", "population"], ["treecover"]],
-            [True, False],
-            [(0.5, 0.0), (0.0, 0.5)],
-        ),
-    )
+    @pytest.fixture()
     def datamodule(self, request: SubRequest) -> USAVarsDataModule:
-        labels, fixed_shuffle, split = request.param
-        val_split_pct, test_split_pct = split
         root = os.path.join("tests", "data", "usavars")
         batch_size = 1
         num_workers = 0
 
-        dm = USAVarsDataModule(
-            root,
-            labels,
-            None,
-            fixed_shuffle,
-            batch_size,
-            num_workers,
-            val_split_pct,
-            test_split_pct,
-        )
+        dm = USAVarsDataModule(root, batch_size=batch_size, num_workers=num_workers)
         dm.prepare_data()
         dm.setup()
         return dm
 
     def test_train_dataloader(self, datamodule: USAVarsDataModule) -> None:
-        assert len(datamodule.train_dataloader()) == 1
+        assert len(datamodule.train_dataloader()) == 3
         sample = next(iter(datamodule.train_dataloader()))
-        assert sample["labels"].shape[1] == len(datamodule.labels)
-        if datamodule.fixed_shuffle:
-            assert sample["labels"][0, 0] == 1.0
+        assert sample["image"].shape[0] == datamodule.batch_size
 
     def test_val_dataloader(self, datamodule: USAVarsDataModule) -> None:
-        if datamodule.val_split_pct == 0.5:
-            assert len(datamodule.val_dataloader()) == 1
-            sample = next(iter(datamodule.val_dataloader()))
-            assert sample["labels"].shape[1] == len(datamodule.labels)
-            if datamodule.fixed_shuffle:
-                assert sample["labels"][0, 0] == 0.0
-        else:
-            assert len(datamodule.val_dataloader()) == 0
+        assert len(datamodule.val_dataloader()) == 2
+        sample = next(iter(datamodule.val_dataloader()))
+        assert sample["image"].shape[0] == datamodule.batch_size
 
     def test_test_dataloader(self, datamodule: USAVarsDataModule) -> None:
-        if datamodule.test_split_pct == 0.5:
-            assert len(datamodule.test_dataloader()) == 1
-            sample = next(iter(datamodule.test_dataloader()))
-            assert sample["labels"].shape[1] == len(datamodule.labels)
-            if datamodule.fixed_shuffle:
-                assert sample["labels"][0, 0] == 0.0
-        else:
-            assert len(datamodule.test_dataloader()) == 0
+        assert len(datamodule.test_dataloader()) == 1
+        sample = next(iter(datamodule.test_dataloader()))
+        assert sample["image"].shape[0] == datamodule.batch_size
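
Why the expected lengths changed: the fixture no longer parameterizes over random split percentages, so each dataloader length reflects the fixed split sizes shipped with the test data. A small sanity sketch of the arithmetic behind the assertions above (the 3/2/1 per-split sample counts are an inference from batch_size == 1, not stated elsewhere in the diff):

    import math

    def expected_batches(num_samples: int, batch_size: int) -> int:
        # len(DataLoader) == ceil(n / b) when drop_last is False (the default).
        return math.ceil(num_samples / batch_size)

    assert expected_batches(3, 1) == 3  # train
    assert expected_batches(2, 1) == 2  # val
    assert expected_batches(1, 1) == 1  # test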
torchgeo/datamodules/usavars.py (10 additions, 6 deletions)
@@ -10,7 +10,6 @@
 from torch.utils.data import DataLoader
 
 from ..datasets import USAVars
-from .utils import dataset_split
 
 
 class USAVarsDataModule(pl.LightningModule):
@@ -51,16 +50,21 @@ def prepare_data(self) -> None:
 
         This method is only called once per run.
         """
-        USAVars(self.root_dir, self.labels, checksum=False)
+        USAVars(self.root_dir, labels=self.labels, checksum=False)
 
     def setup(self, stage: Optional[str] = None) -> None:
         """Initialize the main Dataset objects.
 
         This method is called once per GPU per run.
         """
-        dataset = USAVars(self.root_dir, self.labels, transforms=self.transforms)
-        self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
-            dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
-        )
+        self.train_dataset = USAVars(
+            self.root_dir, "train", self.labels, transforms=self.transforms
+        )
+        self.val_dataset = USAVars(
+            self.root_dir, "val", self.labels, transforms=self.transforms
+        )
+        self.test_dataset = USAVars(
+            self.root_dir, "test", self.labels, transforms=self.transforms
+        )
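
The setup() change implies the refactored USAVars dataset takes the split name as its second positional argument; that is also why prepare_data() now passes labels by keyword, since a positional labels value would land in the split slot. A hedged sketch of constructing a single split directly, inferred from the calls above ("treecover" is a label name taken from the old test parameters):

    from torchgeo.datasets import USAVars

    # Assumed signature: USAVars(root, split, labels, ...); not confirmed here.
    val_ds = USAVars("tests/data/usavars", "val", labels=["treecover"])
    sample = val_ds[0]  # assumed to be a dict with "image" and "labels" keys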

@@ -69,7 +73,7 @@ def train_dataloader(self) -> DataLoader[Any]:
             self.train_dataset,
             batch_size=self.batch_size,
             num_workers=self.num_workers,
-            shuffle=True,
+            shuffle=False,
         )
 
     def val_dataloader(self) -> DataLoader[Any]: