Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DataModule: common base class to reduce code duplication #1260

Merged
merged 3 commits into from
Apr 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,11 @@ xView2
Base Classes
------------

BaseDataModule
^^^^^^^^^^^^^^

.. autoclass:: BaseDataModule

GeoDataModule
^^^^^^^^^^^^^

Expand Down
3 changes: 2 additions & 1 deletion torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .eurosat import EuroSAT100DataModule, EuroSATDataModule
from .fair1m import FAIR1MDataModule
from .fire_risk import FireRiskDataModule
from .geo import GeoDataModule, NonGeoDataModule
from .geo import BaseDataModule, GeoDataModule, NonGeoDataModule
from .gid15 import GID15DataModule
from .inria import InriaAerialImageLabelingDataModule
from .l7irish import L7IrishDataModule
Expand Down Expand Up @@ -73,6 +73,7 @@
"Vaihingen2DDataModule",
"XView2DataModule",
# Base classes
"BaseDataModule",
"GeoDataModule",
"NonGeoDataModule",
# Utilities
Expand Down
180 changes: 65 additions & 115 deletions torchgeo/datamodules/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,76 @@
from .utils import MisconfigurationException


class GeoDataModule(LightningDataModule): # type: ignore[misc]
"""Base class for data modules containing geospatial information.
class BaseDataModule(LightningDataModule): # type: ignore[misc]
"""Base class for all TorchGeo data modules.

.. versionadded:: 0.4
.. versionadded:: 0.5
"""

mean = torch.tensor(0)
std = torch.tensor(255)

def prepare_data(self) -> None:
    """Download and prepare data.

    Instantiates ``self.dataset_class`` with ``self.kwargs`` when the
    ``"download"`` entry of ``self.kwargs`` is truthy; otherwise does nothing.

    During distributed training, this method is called only within a single process
    to avoid corrupted data. This method should not set state since it is not called
    on every device; use ``setup`` instead.
    """
    wants_download = self.kwargs.get("download", False)
    if not wants_download:
        return

    # The instance is discarded on purpose: construction is what triggers the
    # download, and no state may be stored here.
    self.dataset_class(**self.kwargs)

def on_after_batch_transfer(
    self, batch: dict[str, Tensor], dataloader_idx: int
) -> dict[str, Tensor]:
    """Apply batch augmentations to the batch after it is transferred to the device.

    The stage-specific augmentation (``train_aug``, ``val_aug``, ``test_aug`` or
    ``predict_aug``) is used when it is set; otherwise ``aug`` is used. If no
    trainer is attached, or the trainer reports none of the known stages, the
    batch is returned unchanged.

    Args:
        batch: A batch of data that needs to be altered or augmented.
        dataloader_idx: The index of the dataloader to which the batch belongs.

    Returns:
        A batch of data.
    """
    trainer = self.trainer
    if not trainer:
        return batch

    if trainer.training:
        stage_aug = self.train_aug
    elif trainer.validating or trainer.sanity_checking:
        # Sanity checking runs validation batches, so it shares ``val_aug``.
        stage_aug = self.val_aug
    elif trainer.testing:
        stage_aug = self.test_aug
    elif trainer.predicting:
        stage_aug = self.predict_aug
    else:
        # No known stage is active. Previously ``aug`` was left unbound here,
        # which raised UnboundLocalError; pass the batch through instead.
        return batch

    aug = stage_aug or self.aug
    return aug(batch)

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
    """Run the plot method of the underlying dataset if one exists.

    ``self.dataset`` is used when it is truthy, otherwise ``self.val_dataset``.

    Should only be called during 'fit' or 'validate' stages as ``val_dataset``
    may not exist during other stages.

    Args:
        *args: Arguments passed to plot method.
        **kwargs: Keyword arguments passed to plot method.

    Returns:
        Whatever the dataset's ``plot`` method returns (a matplotlib Figure with
        the image, ground truth, and predictions), or None when no dataset is
        available or the dataset has no ``plot`` attribute.
    """
    dataset = self.dataset or self.val_dataset
    if dataset is None or not hasattr(dataset, "plot"):
        return None
    return dataset.plot(*args, **kwargs)


class GeoDataModule(BaseDataModule):
"""Base class for data modules containing geospatial information.

.. versionadded:: 0.4
"""

def __init__(
self,
dataset_class: type[GeoDataset],
Expand Down Expand Up @@ -100,16 +161,6 @@ def __init__(
self.test_aug: Optional[Transform] = None
self.predict_aug: Optional[Transform] = None

def prepare_data(self) -> None:
    """Download and prepare data.

    Instantiates ``self.dataset_class`` with ``self.kwargs`` when the
    ``"download"`` entry of ``self.kwargs`` is truthy; otherwise does nothing.

    During distributed training, this method is called only within a single process
    to avoid corrupted data. This method should not set state since it is not called
    on every device; use :meth:`setup` instead.
    """
    wants_download = self.kwargs.get("download", False)
    if not wants_download:
        return

    # The instance is discarded on purpose: construction is what triggers the
    # download, and no state may be stored here.
    self.dataset_class(**self.kwargs)

def setup(self, stage: str) -> None:
"""Set up datasets and samplers.

Expand Down Expand Up @@ -284,60 +335,13 @@ def transfer_batch_to_device(
batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
return batch

def on_after_batch_transfer(
    self, batch: dict[str, Tensor], dataloader_idx: int
) -> dict[str, Tensor]:
    """Apply batch augmentations to the batch after it is transferred to the device.

    The stage-specific augmentation (``train_aug``, ``val_aug``, ``test_aug`` or
    ``predict_aug``) is used when it is set; otherwise ``aug`` is used. If no
    trainer is attached, or the trainer reports none of the known stages, the
    batch is returned unchanged.

    Args:
        batch: A batch of data that needs to be altered or augmented.
        dataloader_idx: The index of the dataloader to which the batch belongs.

    Returns:
        A batch of data.
    """
    trainer = self.trainer
    if not trainer:
        return batch

    if trainer.training:
        stage_aug = self.train_aug
    elif trainer.validating or trainer.sanity_checking:
        # Sanity checking runs validation batches, so it shares ``val_aug``.
        stage_aug = self.val_aug
    elif trainer.testing:
        stage_aug = self.test_aug
    elif trainer.predicting:
        stage_aug = self.predict_aug
    else:
        # No known stage is active. Previously ``aug`` was left unbound here,
        # which raised UnboundLocalError; pass the batch through instead.
        return batch

    aug = stage_aug or self.aug
    return aug(batch)

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
    """Run the plot method of the validation dataset if one exists.

    ``self.val_dataset`` is used when it is truthy, otherwise ``self.dataset``.

    Should only be called during 'fit' or 'validate' stages as ``val_dataset``
    may not exist during other stages.

    Args:
        *args: Arguments passed to plot method.
        **kwargs: Keyword arguments passed to plot method.

    Returns:
        Whatever the dataset's ``plot`` method returns (a matplotlib Figure with
        the image, ground truth, and predictions), or None when no dataset is
        available or the dataset has no ``plot`` attribute.
    """
    dataset = self.val_dataset or self.dataset
    if dataset is None or not hasattr(dataset, "plot"):
        return None
    return dataset.plot(*args, **kwargs)


class NonGeoDataModule(LightningDataModule): # type: ignore[misc]
class NonGeoDataModule(BaseDataModule):
"""Base class for data modules lacking geospatial information.

.. versionadded:: 0.4
"""

mean = torch.tensor(0)
std = torch.tensor(255)

def __init__(
self,
dataset_class: type[NonGeoDataset],
Expand Down Expand Up @@ -386,16 +390,6 @@ def __init__(
self.test_aug: Optional[Transform] = None
self.predict_aug: Optional[Transform] = None

def prepare_data(self) -> None:
    """Download and prepare data.

    Instantiates ``self.dataset_class`` with ``self.kwargs`` when the
    ``"download"`` entry of ``self.kwargs`` is truthy; otherwise does nothing.

    During distributed training, this method is called only within a single process
    to avoid corrupted data. This method should not set state since it is not called
    on every device; use :meth:`setup` instead.
    """
    wants_download = self.kwargs.get("download", False)
    if not wants_download:
        return

    # The instance is discarded on purpose: construction is what triggers the
    # download, and no state may be stored here.
    self.dataset_class(**self.kwargs)

def setup(self, stage: str) -> None:
"""Set up datasets.

Expand Down Expand Up @@ -510,47 +504,3 @@ def predict_dataloader(self) -> DataLoader[dict[str, Tensor]]:
else:
msg = f"{self.__class__.__name__}.setup does not define a 'predict_dataset'"
raise MisconfigurationException(msg)

def on_after_batch_transfer(
    self, batch: dict[str, Tensor], dataloader_idx: int
) -> dict[str, Tensor]:
    """Apply batch augmentations to the batch after it is transferred to the device.

    The stage-specific augmentation (``train_aug``, ``val_aug``, ``test_aug`` or
    ``predict_aug``) is used when it is set; otherwise ``aug`` is used. If no
    trainer is attached, or the trainer reports none of the known stages, the
    batch is returned unchanged.

    Args:
        batch: A batch of data that needs to be altered or augmented.
        dataloader_idx: The index of the dataloader to which the batch belongs.

    Returns:
        A batch of data.
    """
    trainer = self.trainer
    if not trainer:
        return batch

    if trainer.training:
        stage_aug = self.train_aug
    elif trainer.validating or trainer.sanity_checking:
        # Sanity checking runs validation batches, so it shares ``val_aug``.
        stage_aug = self.val_aug
    elif trainer.testing:
        stage_aug = self.test_aug
    elif trainer.predicting:
        stage_aug = self.predict_aug
    else:
        # No known stage is active. Previously ``aug`` was left unbound here,
        # which raised UnboundLocalError; pass the batch through instead.
        return batch

    aug = stage_aug or self.aug
    return aug(batch)

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
    """Run the plot method of the underlying dataset if one exists.

    ``self.dataset`` is used when it is truthy, otherwise ``self.val_dataset``.

    Should only be called during 'fit' or 'validate' stages as ``val_dataset``
    may not exist during other stages.

    Args:
        *args: Arguments passed to plot method.
        **kwargs: Keyword arguments passed to plot method.

    Returns:
        Whatever the dataset's ``plot`` method returns (a matplotlib Figure with
        the image, ground truth, and predictions), or None when no dataset is
        available or the dataset has no ``plot`` attribute.
    """
    dataset = self.dataset or self.val_dataset
    if dataset is None or not hasattr(dataset, "plot"):
        return None
    return dataset.plot(*args, **kwargs)