Add multiple versions of the So2Sat dataset #1283

Merged · 22 commits · Apr 25, 2023
Add multiple versions of the So2Sat dataset
calebrob6 committed Apr 24, 2023
commit bbc898032ba0689ad1cbdcf9efce0de306c3b835
92 changes: 67 additions & 25 deletions torchgeo/datasets/so2sat.py
@@ -24,18 +24,10 @@ class So2Sat(NonGeoDataset):
acquired by the Sentinel-1 and Sentinel-2 remote sensing satellites, and a
corresponding local climate zones (LCZ) label. The dataset is distributed over
42 cities across different continents and cultural regions of the world, and comes
with a split into fully independent, non-overlapping training, validation,
and test sets.
in a variety of different splits.

This implementation focuses on the *2nd* version of the dataset as described in
the author's github repository https://github.com/zhu-xlab/So2Sat-LCZ42 and hosted
at https://mediatum.ub.tum.de/1483140. This version is identical to the first
version of the dataset but includes the test data. The splits are defined as
follows:

* Training: 42 cities around the world
* Validation: western half of 10 other cities covering 10 cultural zones
* Testing: eastern half of the 10 other cities
This implementation covers the *2nd* and *3rd* versions of the dataset as described
in the authors' GitHub repository https://github.com/zhu-xlab/So2Sat-LCZ42.

Dataset classes:

@@ -63,7 +55,8 @@ class So2Sat(NonGeoDataset):

.. note::

This dataset can be automatically downloaded using the following bash script:
The version "2" dataset can be automatically downloaded using the following bash
script:

.. code-block:: bash

@@ -74,18 +67,61 @@

or manually downloaded from https://dataserv.ub.tum.de/index.php/s/m1483140
This download will likely take several hours.
"""

filenames = {
"train": "training.h5",
"validation": "validation.h5",
"test": "testing.h5",
The version "3*" datasets can be downloaded using the following bash script:

.. code-block:: bash

for version in random block culture_10
do
for split in training testing
do
wget -P $version/ ftp://m1613658:[email protected]/$version/$split.h5
done
done

or manually downloaded from https://mediatum.ub.tum.de/1613658
""" # noqa: E501
versions = ["2", "3_random", "3_block", "3_culture_10"]
filenames_by_version = {
"2": {
"train": "training.h5",
"validation": "validation.h5",
"test": "testing.h5",
},
"3_random": {
"train": "random/training.h5",
"test": "random/testing.h5",
},
"3_block": {
"train": "block/training.h5",
"test": "block/testing.h5",
},
"3_culture_10": {
"train": "culture_10/training.h5",
"test": "culture_10/testing.h5",
},
}
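These filenames are relative to the dataset root, with the version-3 splits living in per-version subdirectories. A quick illustration of how the table resolves to a path, mirroring the `self.fn` assignment later in this diff ("data" is a hypothetical root):

```python
import os

from torchgeo.datasets import So2Sat

# "data" is a hypothetical root directory; version-3 files sit in
# subdirectories named after the split strategy.
fn = os.path.join("data", So2Sat.filenames_by_version["3_block"]["test"])
print(fn)  # data/block/testing.h5
```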
md5s = {
"train": "702bc6a9368ebff4542d791e53469244",
"validation": "71cfa6795de3e22207229d06d6f8775d",
"test": "e81426102b488623a723beab52b31a8a",
md5s_by_version = {
"2": {
"train": "702bc6a9368ebff4542d791e53469244",
"validation": "71cfa6795de3e22207229d06d6f8775d",
"test": "e81426102b488623a723beab52b31a8a",
},
"3_random": {
"train": "",
"test": "",
},
"3_block": {
"train": "",
"test": "",
},
"3_culture_10": {
"train": "702bc6a9368ebff4542d791e53469244",
"test": "58335ce34ca3a18424e19da84f2832fc",
},
}
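The "3_random" and "3_block" checksums are still empty placeholders at this commit. A minimal sketch for computing them locally, assuming the .h5 files have already been downloaded; the helper name `md5sum` is hypothetical and not part of the dataset API:

```python
import hashlib

def md5sum(path: str, chunk_size: int = 2**20) -> str:
    """Stream a file through MD5 so large .h5 archives need not fit in RAM."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

print(md5sum("random/training.h5"))
```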

classes = [
"Compact high rise",
"Compact mid rise",
@@ -141,6 +177,7 @@ class So2Sat(NonGeoDataset):
def __init__(
self,
root: str = "data",
version: str = "2",
split: str = "train",
bands: Sequence[str] = BAND_SETS["all"],
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
@@ -150,6 +187,7 @@ def __init__(

Args:
root: root directory where dataset can be found
version: one of "2", "3_random", "3_block", or "3_culture_10"
split: one of "train", "validation", or "test"
bands: a sequence of band names to use where the indices correspond to the
array index of combined Sentinel 1 and Sentinel 2
@@ -163,15 +201,18 @@ def __init__(

.. versionadded:: 0.3
The *bands* parameter.

.. versionadded:: 0.5
The *version* parameter.
"""
try:
import h5py # noqa: F401
except ImportError:
raise ImportError(
"h5py is not installed and is required to use this dataset"
)

assert split in ["train", "validation", "test"]
assert version in self.versions
assert split in self.filenames_by_version[version]

self._validate_bands(bands)
self.s1_band_indices: "np.typing.NDArray[np.int_]" = np.array(
@@ -197,11 +238,12 @@ def __init__(
self.bands = bands

self.root = root
self.version = version
self.split = split
self.transforms = transforms
self.checksum = checksum

self.fn = os.path.join(self.root, self.filenames[split])
self.fn = os.path.join(self.root, self.filenames_by_version[version][split])

if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted.")
@@ -256,7 +298,7 @@ def _check_integrity(self) -> bool:
Returns:
True if dataset files are found and/or MD5s match, else False
"""
md5 = self.md5s[self.split]
md5 = self.md5s_by_version[self.version][self.split]
if not check_integrity(self.fn, md5 if self.checksum else None):
return False
return True
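Taken together, a hypothetical end-to-end use of the new parameter looks like this; band selection and transforms are unchanged from the existing API, and the optional h5py dependency must be installed:

```python
from torchgeo.datasets import So2Sat

# Version and split names follow the tables above; checksum=True verifies
# the MD5s via _check_integrity before any sample is read.
ds = So2Sat(root="data", version="3_culture_10", split="train", checksum=True)
sample = ds[0]  # a dict with "image" and "label" tensors
print(len(ds), sample["image"].shape)
```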