Skip to content

Commit

Permalink
DataModules: pass kwargs directly to datasets (microsoft#730)
Browse files Browse the repository at this point in the history
* Datamodules: pass kwargs directly to datasets

* Rename root_dir -> root in config files

* Fix datamodule tests

* Fix mypy

* Fix tutorial

* Specify all kwarg keys

* Fix bands vs. band_set

* root_dir -> root

* Document **kwargs
  • Loading branch information
adamjstewart authored Oct 1, 2022
1 parent baccd45 commit f79b6ec
Show file tree
Hide file tree
Showing 78 changed files with 285 additions and 343 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ All TorchGeo datasets are compatible with PyTorch data loaders, making them easy
In order to facilitate direct comparisons between results published in the literature and further reduce the boilerplate code needed to run experiments with datasets in TorchGeo, we have created PyTorch Lightning [*datamodules*](https://torchgeo.readthedocs.io/en/stable/api/datamodules.html) with well-defined train-val-test splits and [*trainers*](https://torchgeo.readthedocs.io/en/stable/api/trainers.html) for various tasks like classification, regression, and semantic segmentation. These datamodules show how to incorporate augmentations from the kornia library, include preprocessing transforms (with pre-calculated channel statistics), and let users easily experiment with hyperparameters related to the data itself (as opposed to the modeling process). Training a semantic segmentation model on the [Inria Aerial Image Labeling](https://project.inria.fr/aerialimagelabeling/) dataset is as easy as a few imports and four lines of code.

```python
datamodule = InriaAerialImageLabelingDataModule(root_dir="...", batch_size=64, num_workers=6)
datamodule = InriaAerialImageLabelingDataModule(root="...", batch_size=64, num_workers=6)
task = SemanticSegmentationTask(segmentation_model="unet", encoder_weights="imagenet", learning_rate=0.1)
trainer = Trainer(gpus=1, default_root_dir="...")

Expand Down
2 changes: 1 addition & 1 deletion conf/bigearthnet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ experiment:
in_channels: 14
num_classes: 19
datamodule:
root_dir: "data/bigearthnet"
root: "data/bigearthnet"
bands: "all"
num_classes: ${experiment.module.num_classes}
batch_size: 128
Expand Down
2 changes: 1 addition & 1 deletion conf/byol.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ experiment:
learning_rate: 1e-3
learning_rate_schedule_patience: 6
datamodule:
root_dir: "data/chesapeake/cvpr"
root: "data/chesapeake/cvpr"
train_splits:
- "de-train"
val_splits:
Expand Down
2 changes: 1 addition & 1 deletion conf/chesapeake_cvpr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ experiment:
ignore_index: null
imagenet_pretraining: True
datamodule:
root_dir: "data/chesapeake/cvpr"
root: "data/chesapeake/cvpr"
train_splits:
- "de-train"
val_splits:
Expand Down
2 changes: 1 addition & 1 deletion conf/cowc_counting.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ experiment:
learning_rate_schedule_patience: 2
pretrained: True
datamodule:
root_dir: "data/cowc_counting"
root: "data/cowc_counting"
seed: 0
batch_size: 64
num_workers: 4
2 changes: 1 addition & 1 deletion conf/cyclone.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ experiment:
learning_rate_schedule_patience: 2
pretrained: True
datamodule:
root_dir: "data/cyclone"
root: "data/cyclone"
seed: 0
batch_size: 32
num_workers: 4
2 changes: 1 addition & 1 deletion conf/etci2021.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ experiment:
num_classes: 2
ignore_index: 0
datamodule:
root_dir: "data/etci2021"
root: "data/etci2021"
batch_size: 32
num_workers: 4
2 changes: 1 addition & 1 deletion conf/eurosat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ experiment:
in_channels: 13
num_classes: 10
datamodule:
root_dir: "data/eurosat"
root: "data/eurosat"
batch_size: 128
num_workers: 4
2 changes: 1 addition & 1 deletion conf/inria.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ experiment:
num_classes: 2
ignore_index: null
datamodule:
root_dir: "data/inria"
root: "data/inria"
batch_size: 2
num_workers: 32
patch_size: 512
Expand Down
2 changes: 1 addition & 1 deletion conf/landcoverai.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ experiment:
num_filters: 256
ignore_index: null
datamodule:
root_dir: "data/landcoverai"
root: "data/landcoverai"
batch_size: 32
num_workers: 4
4 changes: 2 additions & 2 deletions conf/naipchesapeake.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ experiment:
num_filters: 64
ignore_index: null
datamodule:
naip_root_dir: "data/naip"
chesapeake_root_dir: "data/chesapeake/BAYWIDE"
naip_root: "data/naip"
chesapeake_root: "data/chesapeake/BAYWIDE"
batch_size: 32
num_workers: 4
patch_size: 32
4 changes: 2 additions & 2 deletions conf/oscd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ experiment:
num_filters: 256
ignore_index: 0
datamodule:
root_dir: "data/oscd"
batch_size: 32
root: "data/oscd"
train_batch_size: 32
num_workers: 4
val_split_pct: 0.1
bands: "all"
Expand Down
2 changes: 1 addition & 1 deletion conf/resisc45.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ experiment:
in_channels: 3
num_classes: 45
datamodule:
root_dir: "data/resisc45"
root: "data/resisc45"
batch_size: 128
num_workers: 4
2 changes: 1 addition & 1 deletion conf/sen12ms.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ experiment:
num_classes: 11
ignore_index: null
datamodule:
root_dir: "data/sen12ms"
root: "data/sen12ms"
band_set: "all"
batch_size: 32
num_workers: 4
Expand Down
4 changes: 2 additions & 2 deletions conf/so2sat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ experiment:
in_channels: 3
num_classes: 17
datamodule:
root_dir: "data/so2sat"
root: "data/so2sat"
batch_size: 128
num_workers: 4
bands: "rgb"
band_set: "rgb"
unsupervised_mode: False
2 changes: 1 addition & 1 deletion conf/ucmerced.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ experiment:
in_channels: 3
num_classes: 21
datamodule:
root_dir: "data/ucmerced"
root: "data/ucmerced"
batch_size: 128
num_workers: 4
2 changes: 1 addition & 1 deletion docs/tutorials/trainers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@
"data_dir = os.path.join(tempfile.gettempdir(), \"cyclone_data\")\n",
"\n",
"datamodule = CycloneDataModule(\n",
" root_dir=data_dir, seed=1337, batch_size=64, num_workers=6, api_key=MLHUB_API_KEY\n",
" root=data_dir, seed=1337, batch_size=64, num_workers=6, api_key=MLHUB_API_KEY\n",
")"
]
},
Expand Down
6 changes: 3 additions & 3 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def set_up_parser() -> argparse.ArgumentParser:
"--gpu", default=0, type=int, help="GPU ID to use", metavar="ID"
)
parser.add_argument(
"--root-dir",
"--root",
required=True,
type=str,
help="root directory of the dataset for the accompanying task",
Expand Down Expand Up @@ -123,7 +123,7 @@ def main(args: argparse.Namespace) -> None:
args: command-line arguments
"""
assert os.path.exists(args.input_checkpoint)
assert os.path.exists(args.root_dir)
assert os.path.exists(args.root)
TASK = TASK_TO_MODULES_MAPPING[args.task][0]
DATAMODULE = TASK_TO_MODULES_MAPPING[args.task][1]

Expand All @@ -135,7 +135,7 @@ def main(args: argparse.Namespace) -> None:

dm = DATAMODULE( # type: ignore[call-arg]
seed=args.seed,
root_dir=args.root_dir,
root=args.root,
num_workers=args.num_workers,
batch_size=args.batch_size,
)
Expand Down
2 changes: 1 addition & 1 deletion experiments/test_chesapeakecvpr_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def main(args: argparse.Namespace) -> None:
for test_splits in ALL_TEST_SPLITS:

dm = ChesapeakeCVPRDataModule(
args.chesapeakecvpr_root,
root=args.chesapeakecvpr_root,
train_splits=["de-train"],
val_splits=["de-val"],
test_splits=test_splits,
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/bigearthnet_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ experiment:
in_channels: 14
num_classes: 19
datamodule:
root_dir: "tests/data/bigearthnet"
root: "tests/data/bigearthnet"
bands: "all"
num_classes: ${experiment.module.num_classes}
batch_size: 1
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/bigearthnet_s1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ experiment:
in_channels: 2
num_classes: 19
datamodule:
root_dir: "tests/data/bigearthnet"
root: "tests/data/bigearthnet"
bands: "s1"
num_classes: ${experiment.module.num_classes}
batch_size: 1
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/bigearthnet_s2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ experiment:
in_channels: 12
num_classes: 19
datamodule:
root_dir: "tests/data/bigearthnet"
root: "tests/data/bigearthnet"
bands: "s2"
num_classes: ${experiment.module.num_classes}
batch_size: 1
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/byol.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ experiment:
learning_rate: 1e-3
learning_rate_schedule_patience: 6
datamodule:
root_dir: "tests/data/chesapeake/cvpr"
root: "tests/data/chesapeake/cvpr"
train_splits:
- "de-test"
val_splits:
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/chesapeake_cvpr_5.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ experiment:
ignore_index: null
imagenet_pretraining: False
datamodule:
root_dir: "tests/data/chesapeake/cvpr"
root: "tests/data/chesapeake/cvpr"
train_splits:
- "de-test"
val_splits:
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/chesapeake_cvpr_7.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ experiment:
ignore_index: null
imagenet_pretraining: True
datamodule:
root_dir: "tests/data/chesapeake/cvpr"
root: "tests/data/chesapeake/cvpr"
train_splits:
- "de-test"
val_splits:
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/chesapeake_cvpr_prior.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ experiment:
ignore_index: null
imagenet_pretraining: False
datamodule:
root_dir: "tests/data/chesapeake/cvpr"
root: "tests/data/chesapeake/cvpr"
train_splits:
- "de-test"
val_splits:
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/cowc_counting.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ experiment:
learning_rate_schedule_patience: 2
pretrained: True
datamodule:
root_dir: "tests/data/cowc_counting"
root: "tests/data/cowc_counting"
seed: 0
batch_size: 1
num_workers: 0
2 changes: 1 addition & 1 deletion tests/conf/cyclone.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ experiment:
learning_rate_schedule_patience: 2
pretrained: False
datamodule:
root_dir: "tests/data/cyclone"
root: "tests/data/cyclone"
seed: 0
batch_size: 1
num_workers: 0
2 changes: 1 addition & 1 deletion tests/conf/deepglobelandcover_0.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ experiment:
num_filters: 1
ignore_index: null
datamodule:
root_dir: "tests/data/deepglobelandcover"
root: "tests/data/deepglobelandcover"
val_split_pct: 0.0
batch_size: 1
num_workers: 0
2 changes: 1 addition & 1 deletion tests/conf/deepglobelandcover_5.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ experiment:
num_filters: 1
ignore_index: null
datamodule:
root_dir: "tests/data/deepglobelandcover"
root: "tests/data/deepglobelandcover"
val_split_pct: 0.5
batch_size: 1
num_workers: 0
2 changes: 1 addition & 1 deletion tests/conf/etci2021.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ experiment:
num_classes: 2
ignore_index: 0
datamodule:
root_dir: "tests/data/etci2021"
root: "tests/data/etci2021"
batch_size: 1
num_workers: 0
2 changes: 1 addition & 1 deletion tests/conf/eurosat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ experiment:
in_channels: 13
num_classes: 2
datamodule:
root_dir: "tests/data/eurosat"
root: "tests/data/eurosat"
batch_size: 1
num_workers: 0
2 changes: 1 addition & 1 deletion tests/conf/inria.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ experiment:
num_classes: 2
ignore_index: null
datamodule:
root_dir: "tests/data/inria"
root: "tests/data/inria"
batch_size: 1
num_workers: 0
val_split_pct: 0.2
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/landcoverai.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ experiment:
num_filters: 1
ignore_index: null
datamodule:
root_dir: "tests/data/landcoverai"
root: "tests/data/landcoverai"
batch_size: 1
num_workers: 0
4 changes: 2 additions & 2 deletions tests/conf/naipchesapeake.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ experiment:
num_filters: 1
ignore_index: null
datamodule:
naip_root_dir: "tests/data/naip"
chesapeake_root_dir: "tests/data/chesapeake/BAYWIDE"
naip_root: "tests/data/naip"
chesapeake_root: "tests/data/chesapeake/BAYWIDE"
batch_size: 2
num_workers: 0
patch_size: 32
4 changes: 2 additions & 2 deletions tests/conf/oscd_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ experiment:
num_filters: 1
ignore_index: null
datamodule:
root_dir: "tests/data/oscd"
batch_size: 1
root: "tests/data/oscd"
train_batch_size: 1
num_workers: 0
val_split_pct: 0.5
bands: "all"
Expand Down
4 changes: 2 additions & 2 deletions tests/conf/oscd_rgb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ experiment:
num_filters: 1
ignore_index: null
datamodule:
root_dir: "tests/data/oscd"
batch_size: 1
root: "tests/data/oscd"
train_batch_size: 1
num_workers: 0
val_split_pct: 0.5
bands: "rgb"
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/resisc45.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ experiment:
in_channels: 3
num_classes: 3
datamodule:
root_dir: "tests/data/resisc45"
root: "tests/data/resisc45"
batch_size: 1
num_workers: 0
2 changes: 1 addition & 1 deletion tests/conf/sen12ms_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ experiment:
num_classes: 11
ignore_index: null
datamodule:
root_dir: "tests/data/sen12ms"
root: "tests/data/sen12ms"
band_set: "all"
batch_size: 1
num_workers: 0
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/sen12ms_s1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ experiment:
num_classes: 11
ignore_index: null
datamodule:
root_dir: "tests/data/sen12ms"
root: "tests/data/sen12ms"
band_set: "s1"
batch_size: 1
num_workers: 0
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/sen12ms_s2_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ experiment:
num_classes: 11
ignore_index: null
datamodule:
root_dir: "tests/data/sen12ms"
root: "tests/data/sen12ms"
band_set: "s2-all"
batch_size: 1
num_workers: 0
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/sen12ms_s2_reduced.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ experiment:
num_classes: 11
ignore_index: null
datamodule:
root_dir: "tests/data/sen12ms"
root: "tests/data/sen12ms"
band_set: "s2-reduced"
batch_size: 1
num_workers: 0
Expand Down
Loading

0 comments on commit f79b6ec

Please sign in to comment.