From 6b16cca61f7243c085e14801eedab3237cf9577c Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Sun, 14 Apr 2024 21:30:46 -0500 Subject: [PATCH 01/18] add quakeset dataset --- torchgeo/datasets/__init__.py | 2 + torchgeo/datasets/quakeset.py | 213 ++++++++++++++++++++++++++++++++++ 2 files changed, 215 insertions(+) create mode 100644 torchgeo/datasets/quakeset.py diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 5f3f974e2b2..739eeeaec27 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -90,6 +90,7 @@ from .patternnet import PatternNet from .potsdam import Potsdam2D from .prisma import PRISMA +from .quakeset import QuakeSet from .reforestree import ReforesTree from .resisc45 import RESISC45 from .rwanda_field_boundary import RwandaFieldBoundary @@ -226,6 +227,7 @@ "PASTIS", "PatternNet", "Potsdam2D", + "QuakeSet", "RESISC45", "ReforesTree", "RwandaFieldBoundary", diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py new file mode 100644 index 00000000000..12a5767d249 --- /dev/null +++ b/torchgeo/datasets/quakeset.py @@ -0,0 +1,213 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""QuakeSet dataset.""" + +import os +from collections.abc import Callable + +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib.figure import Figure +from torch import Tensor + +from .geo import NonGeoDataset +from .utils import DatasetNotFoundError, download_url, percentile_normalization + + +class QuakeSet(NonGeoDataset): + """QuakeSet dataset. + + `QuakeSet `__ + is a dataset for Earthquake Change Detection and Magnitude Estimation and is used + for the Seismic Monitoring and Analysis (SMAC) ECML-PKDD 2024 Discovery Challenge. + + Dataset features: + + * Sentinel-1 SAR imagery + * before/pre/post imagery of areas affected by earthquakes + * 2 multispectral bands (VV/VH) + * 356 pairs of pre and post images with 5 m per pixel resolution (512x512 px) + + Dataset format: + + * single hdf5 dataset containing images, magnitudes, hypercenters, and splits + + Dataset classes: + + 0. unaffected area + 1. earthquake affected area + + If you use this dataset in your research, please cite the following paper: + + * https://arxiv.org/abs/2403.18116 + + .. note:: + + This dataset requires the following additional library to be installed: + + * `h5py `_ to load the dataset + + .. versionadded:: 0.6 + """ + + all_bands = ["VV", "VH"] + filename = "earthquakes.h5" + url = "https://hf.co/datasets/DarthReca/quakeset/resolve/main/earthquakes.h5", + md5 = "76fc7c76b7ca56f4844d852e175e1560" + splits = ["train", "val", "test"] + + def __init__( + self, + root: str = "data", + split: str = "train", + bands: list[str] = all_bands, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new QuakeSet dataset instance. + + Args: + root: root directory where dataset can be found + split: one of "train", "val", or "test" + bands: the subset of bands to load + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + AssertionError: If ``split`` or ``bands`` arguments are invalid. + DatasetNotFoundError: If dataset is not found and *download* is False. + """ + assert split in self.splits + assert set(bands) <= set(self.all_bands) + + self.root = root + self.split = split + self.bands = bands + self.transforms = transforms + self.download = download + self.checksum = checksum + self.filepath = os.path.join(root, self.filename) + self.band_indices = [self.all_bands.index(b) for b in bands] + + self._verify() + + try: + import h5py # noqa: F401 + except ImportError: + raise ImportError( + "h5py is not installed and is required to use this dataset" + ) + + self.data = self._load_data() + + def __getitem__(self, index: int) -> dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + sample containing image and mask + """ + image = self._load_image(index) + label = torch.tensor(self.data[index]["label"]) + magnitude = torch.tensor(self.data[index]["magnitude"]) + + sample = {"image": image, "label": label, "magnitude": magnitude} + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + length of the dataset + """ + return len(self.data) + + def _load_data(self) -> list[dict[str, str | tuple[str, str], int | float]]: + """Return the metadata for a given split. + + Returns: + the sample keys, patches, images, labels, and magnitudes + """ + import h5py + + f = h5py.File(self.filepath) + + data = [] + for k in sorted(f.keys()): + if f[k].attrs["split"] != self.split: + continue + + for patch in sorted(f[k].keys()): + if patch not in ["x", "y"]: + # positive sample + magnitude = float(f[k].attrs["magnitude"]) + data.append(dict(key=k, patch=patch, images=("pre", "post"), label=1, magnitude=magnitude)) + + # hard negative sample + if "before" in f[k][patch].keys(): + data.append(dict(key=k, patch=patch, images=("before", "pre"), label=0, magnitude=0.0)) + f.close() + return data + + def _load_image(self, index: int) -> Tensor: + """Load a single image. + + Args: + index: index to return + + Returns: + the image + """ + import h5py + + key = self.data[index]["key"] + patch = self.data[index]["patch"] + images = self.data[index]["images"] + + with h5py.File(self.filepath) as f: + pre_array = f[key][patch][images[0]][:] + post_array = f[key][patch][images[1]][:] + + # index specified bands and concatenate + pre_array = pre_array[..., self.band_indices] + post_array = post_array[..., self.band_indices] + array = np.concatenate([pre_array, post_array], axis=-1).astype(np.float32) + + tensor = torch.from_numpy(array) + # Convert from HxWxC to CxHxW + tensor = tensor.permute((2, 0, 1)) + return tensor + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the files already exist + if os.path.exists(self.filepath): + return + + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) + + # Download the dataset + self._download() + + def _download(self) -> None: + """Download the dataset.""" + if not os.path.exists(self.filepath): + download_url( + self.url, + self.root, + filename=self.filename, + md5=self.md5 if self.checksum else None, + ) From f0f0bfc946a471740dbfcf6cd90c3765a9d0d6a6 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Mon, 15 Apr 2024 16:12:02 -0500 Subject: [PATCH 02/18] add datamodule and tests --- docs/api/datamodules.rst | 5 ++ docs/api/datasets.rst | 5 ++ docs/api/non_geo_datasets.csv | 1 + tests/conf/quakeset.yaml | 14 ++++ tests/data/quakeset/data.py | 49 +++++++++++++ tests/data/quakeset/earthquakes.h5 | Bin 0 -> 330104 bytes tests/datasets/test_quakeset.py | 70 +++++++++++++++++++ tests/trainers/test_classification.py | 1 + torchgeo/datamodules/__init__.py | 2 + torchgeo/datamodules/quakeset.py | 42 +++++++++++ torchgeo/datasets/quakeset.py | 96 ++++++++++++++++++++++++-- 11 files changed, 279 insertions(+), 6 deletions(-) create mode 100644 tests/conf/quakeset.yaml create mode 100644 tests/data/quakeset/data.py create mode 100644 tests/data/quakeset/earthquakes.h5 create mode 100644 tests/datasets/test_quakeset.py create mode 100644 torchgeo/datamodules/quakeset.py diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index f41463b9b76..b21ee12a38e 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -128,6 +128,11 @@ Potsdam .. autoclass:: Potsdam2DDataModule +QuakeSet +^^^^^^^ + +.. autoclass:: QuakeSetDataModule + RESISC45 ^^^^^^^^ diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 341be4d4916..723a376efb1 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -348,6 +348,11 @@ Potsdam .. autoclass:: Potsdam2D +QuakeSet +^^^^^^^ + +.. autoclass:: QuakeSet + ReforesTree ^^^^^^^^^^^ diff --git a/docs/api/non_geo_datasets.csv b/docs/api/non_geo_datasets.csv index a34b918b5ec..a0d4c30ad78 100644 --- a/docs/api/non_geo_datasets.csv +++ b/docs/api/non_geo_datasets.csv @@ -29,6 +29,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `PASTIS`_,I,Sentinel-1/2,"CC-BY-4.0","2,433",19,128x128xT,10,MSI `PatternNet`_,C,Google Earth,-,"30,400",38,256x256,0.06--5,RGB `Potsdam`_,S,Aerial,-,38,6,"6,000x6,000",0.05,MSI +`QuakeSet`_,C,Sentinel-1,"OpenRAIL","3,327",2,512x512,5,SAR `ReforesTree`_,"OD, R",Aerial,"CC-BY-4.0",100,6,"4,000x4,000",0.02,RGB `RESISC45`_,C,Google Earth,"CC-BY-NC-4.0","31,500",45,256x256,0.2--30,RGB `Rwanda Field Boundary`_,S,Planetscope,"NICFI AND CC-BY-4.0",70,2,256x256,4.7,RGB + NIR diff --git a/tests/conf/quakeset.yaml b/tests/conf/quakeset.yaml new file mode 100644 index 00000000000..9d54e1b6d4f --- /dev/null +++ b/tests/conf/quakeset.yaml @@ -0,0 +1,14 @@ +model: + class_path: ClassificationTask + init_args: + loss: "ce" + model: "resnet18" + in_channels: 4 + num_classes: 2 +data: + class_path: QuakeSetDataModule + init_args: + batch_size: 2 + dict_kwargs: + root: "tests/data/quakeset" + download: false diff --git a/tests/data/quakeset/data.py b/tests/data/quakeset/data.py new file mode 100644 index 00000000000..3d6eb66938b --- /dev/null +++ b/tests/data/quakeset/data.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import os + +import h5py +import numpy as np + +NUM_CHANNELS = 2 +SIZE = 32 + +np.random.seed(0) + +filename = "earthquakes.h5" + +splits = { + "train": ["611645479", "611658170"], + "validation": ["611684805", "611744956"], + "test": ["611798698", "611818836"], +} + +# Remove old data +if os.path.exists(filename): + os.remove(filename) + +# Create dataset file +data = np.random.randn(SIZE, SIZE, NUM_CHANNELS) +data = data.astype(np.float32) + + +with h5py.File(filename, "w") as f: + for split, keys in splits.items(): + for key in keys: + sample = f.create_group(key) + sample.attrs.create(name="magnitude", data=np.float32(0.0)) + sample.attrs.create(name="split", data=split) + for i in range(2): + patch = sample.create_group(f"patch_{i}") + patch.create_dataset("before", data=data) + patch.create_dataset("pre", data=data) + patch.create_dataset("post", data=data) + +# Compute checksums +with open(filename, "rb") as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(f"md5: {md5}") diff --git a/tests/data/quakeset/earthquakes.h5 b/tests/data/quakeset/earthquakes.h5 new file mode 100644 index 0000000000000000000000000000000000000000..71deb28b2df399744b20f490d7cba955c566f76f GIT binary patch literal 330104 zcmeI530O_r+rZB`iApG>GUOm6QISNQ^`1}}E5fCSN~%jFMO_lk1y&aLR+(Bl~5OcAFTc&R+lmAGDV{LA{G_@ClVQ{*XyhIU$$NE zUQSM;qL#w`g;Kql=$h*AqJF{(D*sUt=(ksh>^9vkgB~q2vC4H~>H@VQ@|5$psx^)tZDvGq!+euZA%kTT`tgZ*D z&qFw`lmOKR!niKFr>aXIs7f>S{cakFI;c)C=9+4~5I?&v!gH_M@Yxx4yAbsSsrUD16~cR~Y`kwM2%I>h;3)sjH$t_EWKBKchp22Za^=Xe<_& zj2XpmWz{ySJ1J~ma@*9e?BaD_8{Z1mOv{3*C_f6 zzIe;8jc1xgcdYn#i9}Lz04?=j{6&Wob|sWWt%{dGRq+>ns`{12@BONv^0yCcB>X&B z`0~G(LjSX>`oI7FcUxPI`juQATWRSS|F!+t6D}{Gk6#9h?j}N-TQ(2QN#Y-E zqp`_>ov0UZ0VQegATC0J&fTu@M6Z>&RNDn?79PX+g8sa_%pYqRWix4(1YeuPuJht1-66Zbg#$PRKV*ZpBd{vd9vQ>>1VcgQ8 z;5O_DTfHU=bt5@MroI6Ajm6NG+4133XP|RoG_(mm#%qpS3fj(nSOv}TkP^O&eYkD| zE?O22f$_J1y%w|CC-Rs<<>%l%B^4c4tVUn+2)Jjq7#HRi!j*%kcbJ7XSt^a|g!;}J}*V*zKE=QBHFCv?!&0sU1LxPP!jc09Ww?lo@#;Y(_OQ};f6 z$iBmD)cu(_{W`~&Z?CcNRC6{wt|5d@Gm-}$8U=ec2jP(Sb71|*6g<)U5qI1_Q<3j_ z5gbz(j*s)e@v9CgcAsyCOWMqXR`#*zcx^tuY<7*;{rs8Vd;bPlmC>N@nkK7$aVc7u z^@FOFCSkR`2XTOY0`zcR%>op|z{zS2llS+6WgD7cb>FG@enLxdOznYvqpfg`{e2#0 zHP>Mq+sOkyZ^0H>rYP%PRoT3n6x!?M;m85AvC40KWTM_OjEXu4eJ8BLTNyD-^R84; zYw98#^I8+mZLs9}3!R|(iCz#Msw3}T<2MM>vtS@Lg$-95!Ow-qaFI(+ zX2Mekbm=vhRo6JAm>>><;Gb8s!`-{Ugkw**eaU+Lb{(#OVE>_HE5y8 z-xG$>rrUVffX7@h^$j~~cMzlIM#c38doHq2Od7*9FYg;wO01^dJpayIZR%mDe5%O z<>^O0vb#Bt6`60u%JtF;a#Oz>3e826p{2Psic_*+?UMP-`f?q(e0Uljuebyr#aqFh z)ydHRuoSJjyy1gJ*N`V}o{mFpEAy0TaZt%LR~F`A0E?y`9OKB;K3#2V}Z8gae*T72@$8F1Qu9&WE0Z2XC$5V@5aNW+U|EvH%;LF>VxO zx&ZXIOu&0DogiRFJYKgnSMG0H2P*EnhBvO=VL#Qe#Ad@!!-K3CuxxmbOKVx-&7n(h zsq#AC7dw~NOV8l%$4J>n#!;Nt$jznyV+qrsOAj5$#))BT|Afd zm;H)O!heG5p^;p^x2vp0!AI6%t0`FN9KyL*Jb0jo55#z#$CX{OnYovqGV1zdtZr{2 zW7SUM$gUI6&_5Gvy*UZJnwzm_?eyV!g&|z>dIw*5?-=eJwum=<-U08f^n#Vu=EH?O zd!az?3BHcTxa-Uu)ald-G=}eD5rb~vF7YllXsSEB_8KMI^{b3m+3$)~%+j&46&vNz3!VmVvOn>>J6v^M zDGT#D%x?uO#E6eu6dA9a*#g_aFjnzU;nV5?%YQZoeLXwlp&B<7`k$lmz=xhNto~!R z_VyfhZt5nM(@!ifton@qd`nm9Y0#Z*aB@>Lm{0)|7LLc!m#^{Mt?gLC9ucnfuPtA8 zGaMD#-B1}(T^YB35<8|VR;C&2;8xSweDBsa=(FGoteW>Lj6a_Ot?&5Y!SH!-_T?vb zX}lAx&zb@|*9KxU`FU=1Y5;6}-AtLeV;|@xRe_o|_KN*Seb~d(Q!uA_8aFvR3l=v{ zhNp8Z&6^|F2;n}qs*vf4(x`(b%M7=)1Ry*n{jQY%iPG?8*Dz~?Bjbk;j`TKla zmYR#nPZndVMaht3=?|wqZ-M*MHsb9MA8_}9$77kr6hHJjHWHfV>;U)mhRT>$10l6fHD#B$JT^CXDY%@PD!X`e z4W7JWfJ5F)Yg6Uef#bQc$UQ{u$z34Bo;OwYNY(R zZ9If#HNab!nzDwS4$9)4MnOx-3~s!j1?JaJhLx^{%Bu^K9Ztri;gx8`&$$cBTfLLHJgNs;1FAUS{fbb#?>+X^<^j4dFqAni>I4HNhOnbs z29K4wVu9>k}9DKcE!{drrT@jZe%__jVkX5M_U&MjtpHhI^%HLv1^{A6r$zc!e4^oMFSD)HwI#~HNs zXZ^BLm~_cn@VKqS7>gFLdGP_>aIHTduT0{dYbg1bmrKz}Vb8`~KLd(~{tA!Sof)TJ zSYqVjXA194PuQ?|3NG${k8g`O#s)}+q0A+UEfw{GBh@#9P3?3=aFs1EGhiwlsT7Gp zx0hfm)1|!GjwxVUd7)x^mNDMjk`0Y-&u7c0)>K~iUCn$OXu@mF0Jc$nhaJh8p=h-) z1s@z$@M|?fv2N}NY?rzUlT+`q?1{OuC*ilbhJ6#vzmAzz8^uC(Cz0~aYbDD~tBiHq zKgXlG@33WR5Z22X%x1zNX6zG!yv1bjFx>^0jvQv@Lk+R;(Og*O;|M=j^#*ZOGyEy2 zjWTLnC3#2dV>mOax}u@Ck=)C113Mue4Kc1tjE}vEwu7{A>#~`!)MXp*65I`+?e5Ah zOo)ZNmwT|r%Oq@nrwi7tHw&tUhcKIPE&08Sk?;nqDXKqN1*!A+J`5D>*y-@tVTAUwOWF2`#P`(-Usx8~*Kb(S;uUMIVZeheAKoA1i%X56fXFOeUk!o#V~(>7(HB-b?ri2G z!vc)!eMixurxx~W*chr#je=iBnxR$xQ`Xt~3bcD_j<#LSfqq9%G_IBDP;04e<_&)b z2mR2#nDdh%?k@U$ZNPdq@O=uq8eUZ%H(?9ZnkGTFKFLhlF%mXyeeb|l&QXBeLv~Kn z8)h``k4<S&g7OOpmB6-9wFX^6%}9LgtHfB_Bx1FDmRh$x_pdRtGyTRzix>IEi;+c z^$zT@jwedo?gK;WItyb$d|;-=FWGg@-E9Vm`m?WxPSEs ztoEDA6D}`fI{sU+&)yZ>soqa``RXLRCp!n-Pjy8PLoJMP-^&^my$8~JzhTaaR`TeF z35s9FOl2{3X0x;{yCLX;7|v=xQ7oGiihge6S*Jm1Ft6T8em<}gm{*;^R#qDbKUHwW zv6i*uv-hsSDb4osGfhGu=y5!IH?B8G#hqJFLU}uL(mviAYTEUJN-Jl@Y47X`T z!1d2V+0^6PVBj5*JV0+5e_Ju1S7|j6Cs(Y&&m|n-Z}j#s5AB*fwoy9w=oAAFjY7~v zd6^}gJI>pNv;(>OHrQ^tgzF!_&c++GWiOJuqbP0+ORBVx9rHWI44QT0x9=)t4{puJ z4A0$szDHGPK5hskw=3XoVP|Fe(Lrp#yDMuHq5}ta-9v4yu8MXC!`LaENoekL7LJT; z%#!Qn!tOy`@nq4zWe2bJXC1Ge#pm@0uxqt*WWB_bA#`kOh{rSh;P_d1!r~%S$?b~D zQ`2zXrC7YAF&EaIt$;Z%KHATTZj(m8Rk=?cHb@F+p+7Z#SlOOq1>I83L)p>cD98Xl&Mg0M;=~VfV)5 z@paPG{PHm)ILRmAnh|-t&B^ZAY1%X#ALND2MznzCj=g26dOmn(-Eba~bprh|>~Yn6 zZF$y-JqnSlJsdQNR>(STXPu8UgPKETqxOhTnT|$>;qkg|Xd}_bT{D*PXvGXx!KWh3 zYQGvB_!l;;`L7B;i&d~fZUe!gh!wUrkRO_L1NzG?_{fR-*y5=hAh3@W?_1a&PW#S= z6St#qc&pkt{Q5As_VgipGI2QE$4dOy!UcTTsV``{bw2v!%4L?dJizrsCAi#EgRg5m zllL9s%`982lr_msgZZxvSVW=%EpM*IokRD-;(Lp*&aRHoG;jttIn#pI@zm!1`Ywgu z5l%dBoR-4-`2pON?7;$FRDy*!!dRy#+hM>5AK8K$iL9mM6kHyfjn94V^FgzAL9^tM z&~sx8NHpsM?ZVeH%YuAJ4%)}kSA;XY?1rqtuTIR=cp;wd>I9$TH2GAU`>^^+k|JsH zIo4&`1Qul~#myflg7o}LXj(Tk)2kqndt~qA5A0rGa!p%)YVsuJCYcUf4~By`oZ)=$ zY~J9~T$r}fns<+rU}d%)mK>i7?0KTYQLzyw?-wagyr~5fBK$G0{RbR+@i5q3Zw#78 z;xKB|67+oQ4O{jPM%le};F4bthInfOvpLSvDtBU*K|auK;yoT;X{&60q#TT-deHlI z0-moe!Z9PGnCsnqu3>f)26sEe6Y4hMgWq+6#r7N6>HRhN+zpOcWpEO!C!Pj{-Ci^8 zwl28V^$Lt_mX4zq9mbk>!y%(X0DJbEF1}qh3qMY;DgP<&sVs1x7M$I;nl)K*g{|tS z1E!Pgp=muOH}18FpFMvCwf4JU_v5k58z5sTetA%Zr z-R8abR%Ao&r(trfG(K;VSgzwEhN!2zSVRR-%?hnCdRX_@v7peCfOx1QzEdXRFRwH~ zZ-q8KY<@&$o^=QoC5RwKF#x@%nzGnyqp|9(JXm?48s@UzOgeWNn^V0J+q>;OdYGnT zc+y=yf7TlI#h?mbUMU=#C#K@dggkC^=_!|Y(SWu!*7GXeYC~arBfOzIg*R%wo!fNj zj5(F0aN5NkyT@wcK3t@jp_h!;1}1Y0BU@~ona%abhAT{Sy5lS_A09q=Jw*O=D)VLh zeApAPjSaDMhUXoeG4A|XS%g%Z>71#H*Vp#K3j27`|L;NYe;R@B+*6^G*Hb>==m7kA zV0}2cr~yVt>x1jijlA{y*R1lzu{iuxZg7+z}JOdjQ=0~+zw3B8@6xN}lP$H$>~twyDy zf1iEArpG#ge9t)^+2t8eXfYU46q|8+`*+ag#%rGaQWJmOJsGM7w}F?}>dJkc`|*A8 zZxro473ikDk*9C5!^2s(`Ao;FczDif=;*N%yZ+Q#+07&Y?HAR?#ik3;cF`=@-M9)L z_xiZQ>JD|mA!I)Dn=&2-Xw1RJ=DK)lf;pJlHdN|X+=`VPMPPIDHOI9hVN?@0sOy!8 z1KQ*>hZBiV`CTp6$bLMg)%O9*-x@;fy`uM;OIuK^2c>kH4_pqp3pM9nWlB#ccC1S) zxYc|UYjm**EIg)cf6%!TCeNOPzs`unOSlCO{uYDHM_q@XbsgcTSceV2cNSe5|@mV;n*y_ zqW8y3cC?<3vU7F{nCpb|(Yq(`hfQO#*35E!u+1nrO2kAny z>W5`#?ss7}bEd1w78fGqD#Z!{%Lp6;dFwIp1Qr0w7+MgbaP90t`XV}Hf7BrMU zZWRn+SE_>+2~jK96B-z2?F3I3M_AHx(kM4M)FyJK)u}t+?G<8@1tR&AOMGw)1Lr03k?W6T|R%i10gUO5F5 zYVQPhdnt37F_33u-eAA18N{Zw-^#oBMe&?qZ)~S+256;)o_S%|pr9G+(O@HJXgVvR zQ|sWEBga^)J8wa!ARP_19$@Vyy_o5yh0rc_BBWXY@3{ONz8&uhR>%9`!hxaK^wWI! zwM7sPo0kUXUjD`=`mJHddoE*FcAvl|R?eJKgphcv7nn1|-OGhp{oZTW_SPg#?jcVyc;m_d(qjWNR0 zl4W|=mrWfig~;JKsFf$i)jj*co|hK*r2DVvwoM{m(V(vEL+hHDu%ZA*?P?E;-xvLV z(YFcwy1{I`TIW#a#`k?tx5X%I<8w;!W-$tKlY<^Jmx8E*y#rqX^ zJ@kZKa^I&|z3CSmQRt5yJ$gVhuSd8ur8cx$oR41nDk(h%J;X=+1I&5enMe29CTp^B zCj>MZ2?cexW9zZrppoRjMm*ch6`!&luII0Tv}3+_Q=Z5C+L>fd)V_mb2lv42^G?jQ zS~4taJ_#P)FoJtK-Ld7IJMbl709%|t2}gCa=QBe5q0Q2F_Vcv;pulN145*&RhR*lG zRXcRpnJ;!Sy)TcUXY+i#Q&oqXO*g=|^M7Mi#dVeEcg}&njy?{}>p#MiH}~+@x9Y&W z9Y^qFa}WFOUei&#PgM-lx#-|~!JakKe9jAho5?;5?Fkq1=I~{*-EgMaUO4;fMLw-d zBV3_z28S#W^Qkp6_}!5%5bv-Lq2e6A`pa0@QrI4)<{j~;HtG1w?xykpPr$5E$!L10 z2in*s%Z{kvX-*=*9Sn>MtZyBMD?d1cpT;8Yda()&n%x{XvayQ&6>ZVN!46jW{K{i$+hdP$;bY+~RxhI-&R0BPNs$+!u}&XM9cRU68lOUg3Te=A z?-6eJQ)l4wFX2IhJ8*gTF5Y@;WvCTWo9X9G=9P2LLCcpX@yoLWRzn&Drxw1-jMjLe z$f#SLTQ^#UW182Mg*$41a^fsmcEC(7Q@rQzzRckVKCOh{nVQPN2BxxKkFVzse#(Fo z-TY*gt8>VITuHgh^&#Vab6JH?+hCXBN!+>5geP^KF1uIoiJ#oF1e5`r`0REPd5@>j za5iW(*tM(%Et(G3U24T4wm65-aY z`z*2MZhqFuOy>L85Sm?Y;&5}Tu6*gzi5U9RBs^g-9(O&dh08`)Rvxr!fWOI_;+rEY zWv}m_VL`YYUj+??Yg_KXw$6R{m#K%~?tyMBYtCnAT=ehLN-c9SI3dBI{;Oy>5A>wtZDC< z(}*|!&>Q`v#+W>BKR$YPhHZM#NU=-43TK5saL~E<8-xs(;fTH|V7GJ~k66%^EwIj& z)%Q<>UmO~;Q_be%^}O}C-!==+HMoc-g$^>Ue&ING(mF*QD;uoe%LxZqzu}f9HV}R( z7<%h)rr|#kH4mB=F&HbDyL=YA5?vclj%@>*ekouJ$Ga$xzc+*0k$O-sIV8ApTGMpZGuVf8zhd z|B3$-|0n)W{Ga$g@qgm~#Q%x^6aOdvPyCLv^=`%CG%@O;Inrj;)J*WQmuP*i+wzXq%qSy_;B|r{Ny%W$uor-(14hTZsP?|0n)W{Ga$g@qgm~ z#Q%x^6aOdvPyCD zzKeahZUZh_77l^&w}8DCv)L!|m_g;|;5{W39apSIU-JmKXSNs@<`%-0gQs}w{(x^A zzJbY63b4?93Nv2>K$md?P%`lqDB8~8xxqc4PGl&&%`oAvHuXSZFdWWjD4BWsEZnp` zfobGyfY!&Y@n}0^9(wc&-?ZZqOs-=A#Q%x^6aOdvPyCMk!p`(k&dXQNJ|v#6EGquEPRN+`T+XU`l1RY-=!i$EsZ<6E z{ZuU3&*+fhL19Hd8jHmxM=n0PY8$PRO@!?=i)-prlN5dZ-t%MitYmXx6IF#i|1l9L z`}|1NS3XHSt_p$h^&dJiw5St@6n_*gAtOf@cgNqKBfH|t|6%cPb?YP+Ns68j)si2e zDusS5f2gVtE;GI-)KIk$+9}obYX-s^^$!0bg;=Dix(DjZ7mHMPLad=$AyvmgD2vC0 z;ze5W`9~?N*AN+%?C)#ilxEQ#E8eq2A}Kk5R#E%nMME6%yyVLtyHr;L)#YDPsQQ)0 z@BK>p^xMZ3-WM^|Rq^^-YxR&SiqQXdyRU8kqphdBOEm8id6(o}l6OhoC3%R2PF72f)82N$NHXs7-*=v8o{qAeUGX3caOl= z{6kT#(w^0y?H`t-=ed7*c%I98bN$N1zidAgp8J^brSDVOcEWQVG49*#%5L|g*2P73 zckz`9@zGP8)q5tYu0KV6VP*Th(7w>9be#TdZ@qZb{_p1}jN5~MRULQIBvncNANhae z|B?Sk{vY{&|405G`G4g9k^e{jANhae|B?Sk{vY{&|405G`G4g9{rLVL@qgm~#Q%x^6aOdvPyCCe4!IYm;~!cI8fHPuGjAW&0Re;^idv)VCqprl)i7CKl*vxQBb0yW{CM zm$3B~;{U||iT@M-C;m_TpZGuVf8zhd|B3$-|0n)W{Ga$g@qgm~#Q%x^6aOdvPyC$LIg$C_hG&IwOL#bWVtWZ2q5BZ;Xz?VzZ#3hB^g(W`<-eVt&(ko?RAQ4>Q1Rq^!0oF zIRAiZOX|lZVf@E^l+Dv4RYzc@I!}%ekUu{^)&HKS$1X?JA0wJcN)9c&r-l4MLcSr9 zwd(i5>J$80-k*XvRlC9ks#3Tfn*WFVKl1;`|0Dm8{6F&l$p0h%kNiLK|H%I%|Bw7X z^8d*HBma;5Kl1;`|0Dm8{6F&l$p0h%kNiLK|H%I%|Bw7X^8bE(|Bv`T@qgm~#Q%x^ z6aOdvPyCq5J-^A!m53;a_` zVSWbnWF$p}UCE`3HB=FjsskaE!Ym9zT_}b18X}{TU43ny6wRVLR=j74L{f49t)lkD zi-tJjkoulTOIG~;@z>6yr0Q21zxONUe)U=Xem2sypKRp8E|B?Sk{vY{&|405G`G4g9k^e{jANhae|B?Sk{vY{&|405G`G4g9 zk^e{j-;eMA5&tLtPyCC_H@0-TDFgoC08@vtrrc|6MI%XwBN{$-;oW#5OwIGbWsdi)IW3G)jaWF=CS z*1x|8<^SGcMS4a56)zESs^bhUsfk3oB6-O-u}DW$vnZQduus5|405G`G4g9k^e{jANhae|B?Sk{vY{&|405G`G4g9k^e{jANhae|NZ#>AMt zxM*281jgS2_FBwlpU7hdm7jz6lvH$Fu^N5NBjBFdVqBP82v-iC;;s7wzHRshCPyj2 zLiZ`md=UU$#tlHp#8;qbJA>y2_kcQ)q3|}tguB|*1BJnGIG>?p=IOI=)Aj_Wk+T6> zAGgM%?TmTo(JOq@jz=)Ljs+0^C;m_TpZGuVf8zhd|B3$-|0n)W{Ga$g@qgm~#Q%x^ z6aOdvPyC6ix7}MJO7j_-;$&#qlRkf9Hr{PLO(80=iw+^P`ZCg6IFGgUCaelJ>`ZYPY7dI{fd_aDBzzijvA zs9&Y4BQ7m{@qcZ<5*NK~i^m+HzXBEO~+xVB{GaG`7RP^7R|U){JTUVDLH_adPV~c zafEt?f}#ydEr0Di3*Vcs$m)vfN`?72N}VcI{`L|0+I&U1W#<3+v*#-+N6)EPLv@#> zrEm6M`<#}YXGeNj7=cQkl~5DrLy2%ypS|!ynNoPJBd&gXe`PlNVXNh+Uo|KX{i^Kq zE%cj|Yo+@`h)WI;o;CSiPS==Hb&RnTIdfCtyTS zSojcs_3vMM&z8Jg|AliQ|Bw7X^8d*HBma;5Kl1;`|0Dm8{6F&l$p0h%kNiLK|H%I% z|Bw7X^8d*HBma;5Kl1;`|0Dm8{6F&l$p0h%kNiLK|9*V`kN7|Ff8zhd|B3$-|0n)W z{Ga$g@qgm~#Q%x^6aOdvPyCdRBiGO%;7y-d0Zc->Z6@UliG z#BMZU)0XdJkDdHb5&H>eHrUQz7)zP?_+_x@ZX%?)W%J;iB>vGh8k-#0iFyGSP?Giz z;vyvI-0d1q^je8awOznw;W3Oa=+C>${IQl%Hj`#a@U=-i4!d$Ev#0At*0Oz!Eb(#@ zd+OT|ZPU}acM}WrG~C0z%-!+yn@iYw3-N#A|HS`^{}cZw{!jd$_&@P~;{U||iT@M- zC;m_TpZGuVf8zhd|B3$-|0n)W{Ga$g@qgm~#Q%x^6aOdv|Ks!ja+Lp}OL@ru@V)#$ zxh|@+6!L`p3n><;$=@gf`XbexkLjT52QlirAb;fPs92JxV|2*yps=FrHWrIZwkY08 zm6JoOBpO0~j*=W%>dvHD(neg;PRKhV9M8&0bv$=<|NCA)jrp1SX-S~}p&w=Q*htlJ z2vPSRAuNA(o+?#dCP`72H`Su0?zcjJFWZlWaa2z|`U~xz1*`jSh^iFM_n%T0j{?Q9 zC_4_6jf?R97TO7UghZoMaU1-dfYJOv|405G`G4g9k^e{jANhae z|B?Sk{vY{&|405G`G4g9k^e{jANhae|B?Sk{-0Fs7f95lMv<{l z>@Da^>x)XP~vB6t_qhJU-A;JM@v`*yqk^!5E!`xm^FlAB$0P_f9ay&uK!9D>a zg2KXw_^W^a+IvoyRkf0p5fj{5QxvkdGUY%h~;0$#UO7QC#H z39%bZ*tF&Q*kdO@RK$M5nGLq{7sgU%K7JW2x|;}TZrMCICy9Twjm9PicA{Ru1(c+{ zgSZF@I(NIq6TMdAQf(KoS$GWN3;Of!GJmXPl+C1B5`1kEkHfCq$?WNRk+p0eBTKxT z#Gd*#MBDUq?%l)!Jq`D8FLQT1{pJ$3-a`DJ_&@P~;{U||iT@M-C;m_TpZGuVf8zhd z|B3$-|0n)W{Ga$g@qgm~#Q%x^6aOdvPyC>zZ~TyIbR;~lKhYS zKPS{@^_5f#LVk#Ar@!4!$YYX!>f7zUa+n`(bvf!+@5)2JDo61!R^{dTJq44xZ~e9T zkiM7yN6$fZmcn!M!#x0W!hcExY^|+rZQ9thZKvK_alMV5bz7_dR=2aUvug84y{(N+ zyEeA}t=`Vgww)^bm`GH#y`8n4U2EZYAx}a#Pd(I!TnFX)f_Jrzsx z_lyo19u!t|5M!~pWQ*dhROg^oa%kavN_xM#^H=SR_JplolZCpG93oR+fc(Z{Xv^&Q z@TxP=xiA{q1Rvuy$1MeI=RT~0=6Fa6-^D&$w*eO|3x~k?Tfkn6+3XW}%%Jjf@Sc*2 zjw@E9uXzOAGh2)ca|_|h!Bf0-f55j5-@xQ31z6}lg_$n`pv$-cD4F;Q6m4hl+~6Ki zCo&Y?W|(kSn|h!y7!Kz%l*~MR7H-;}z%+6;Kkm$f)bS{zFHG`in#% z#UDjW$jH&fX)e^cEkq(A|3>k;e`QhSVUmbcSusSST#0HcAwNb$JJs`Mn~dzTHmfe_P?( z?Mh$!k6SB8{mQ32^sBFplh?|OJAd{#`Ty-fH2)9H|3m&C`G4g9k^e{jANhae|B?Sk z{vY{&|405G`G4g9k^e{jANhae|B?Sk{vY{&Fg z&R-Zynfdr-u;^|gq`77D;G87>(KZ^J9N39^0T)n`_737ABqXYGeT*#eauR#$+YoKj)46vO3-mPH!@bPi z@${QZ*m?``f8zhd|B3$-|0n)W{Ga$g@qgm~#Q%x^6aOdvPyC+c+GK3LEE_xtDrd^Qo?t!57%wLMa#k=F#Z;>*J3vNL>@D!{2aWeq@v@B z)#z&;0r$)nsfB*mh literal 0 HcmV?d00001 diff --git a/tests/datasets/test_quakeset.py b/tests/datasets/test_quakeset.py new file mode 100644 index 00000000000..c352cb9f447 --- /dev/null +++ b/tests/datasets/test_quakeset.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch + +import torchgeo.datasets.utils +from torchgeo.datasets import DatasetNotFoundError, QuakeSet + + +def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: + shutil.copy(url, root) + + +class TestQuakeSet: + @pytest.fixture(params=["train", "val", "test"]) + def dataset( + self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + ) -> QuakeSet: + monkeypatch.setattr(torchgeo.datasets.fire_risk, "download_url", download_url) + url = os.path.join("tests", "data", "quakeset", "earthquakes.h5") + md5 = "127d0d6a1f82d517129535f50053a4c9" + monkeypatch.setattr(QuakeSet, "md5", md5) + monkeypatch.setattr(QuakeSet, "url", url) + root = str(tmp_path) + split = request.param + transforms = nn.Identity() + return QuakeSet(root, split, transforms, download=True, checksum=True) + + def test_getitem(self, dataset: QuakeSet) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert isinstance(x["label"], torch.Tensor) + assert x["image"].shape[0] == 4 + + def test_len(self, dataset: QuakeSet) -> None: + assert len(dataset) == 8 + + def test_already_downloaded(self, dataset: QuakeSet, tmp_path: Path) -> None: + QuakeSet(root=str(tmp_path), download=True) + + def test_already_downloaded_not_extracted( + self, dataset: QuakeSet, tmp_path: Path + ) -> None: + shutil.rmtree(os.path.dirname(dataset.root)) + download_url(dataset.url, root=str(tmp_path)) + QuakeSet(root=str(tmp_path), download=False) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + QuakeSet(str(tmp_path)) + + def test_plot(self, dataset: QuakeSet) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = x["label"].clone() + dataset.plot(x) + plt.close() diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 02183978995..fcf21069fb5 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -76,6 +76,7 @@ class TestClassificationTask: "eurosat", "eurosat100", "fire_risk", + "quakeset", "resisc45", "so2sat_all", "so2sat_s1", diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 5ee3a47aaaa..7f0ee8a263d 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -26,6 +26,7 @@ from .nasa_marine_debris import NASAMarineDebrisDataModule from .oscd import OSCDDataModule from .potsdam import Potsdam2DDataModule +from .quakeset import QuakeSetDataModule from .resisc45 import RESISC45DataModule from .seco import SeasonalContrastS2DataModule from .sen12ms import SEN12MSDataModule @@ -76,6 +77,7 @@ "NASAMarineDebrisDataModule", "OSCDDataModule", "Potsdam2DDataModule", + "QuakeSetDataModule", "RESISC45DataModule", "SeasonalContrastS2DataModule", "SEN12MSDataModule", diff --git a/torchgeo/datamodules/quakeset.py b/torchgeo/datamodules/quakeset.py new file mode 100644 index 00000000000..1963ba48ae2 --- /dev/null +++ b/torchgeo/datamodules/quakeset.py @@ -0,0 +1,42 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""QuakeSet datamodule.""" + +from typing import Any + +import kornia.augmentation as K +import torch + +from ..datasets import QuakeSet +from ..transforms import AugmentationSequential +from .geo import NonGeoDataModule + + +class QuakeSetDataModule(NonGeoDataModule): + """LightningDataModule implementation for the QuakeSet dataset. + + .. versionadded:: 0.6 + """ + + mean = torch.tensor(0.0) + std = torch.tensor(1.0) + + def __init__( + self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + ) -> None: + """Initialize a new QuakeSetDataModule instance. + + Args: + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.QuakeSet`. + """ + super().__init__(QuakeSet, batch_size, num_workers, **kwargs) + self.train_aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), + data_keys=["image"], + ) diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index 12a5767d249..ace4d0bd39c 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -5,6 +5,7 @@ import os from collections.abc import Callable +from typing import cast import matplotlib.pyplot as plt import numpy as np @@ -28,7 +29,9 @@ class QuakeSet(NonGeoDataset): * Sentinel-1 SAR imagery * before/pre/post imagery of areas affected by earthquakes * 2 multispectral bands (VV/VH) - * 356 pairs of pre and post images with 5 m per pixel resolution (512x512 px) + * 3,327 pairs of pre and post images with 5 m per pixel resolution (512x512 px) + * 2 classification labels (unaffected / affected by earthquake) + * earthquake magnitudes for each sample Dataset format: @@ -54,9 +57,10 @@ class QuakeSet(NonGeoDataset): all_bands = ["VV", "VH"] filename = "earthquakes.h5" - url = "https://hf.co/datasets/DarthReca/quakeset/resolve/main/earthquakes.h5", + url = ("https://hf.co/datasets/DarthReca/quakeset/resolve/main/earthquakes.h5",) md5 = "76fc7c76b7ca56f4844d852e175e1560" - splits = ["train", "val", "test"] + splits = {"train": "train", "val": "validation", "test": "test"} + classes = ["unaffected_area", "earthquake_affected_area"] def __init__( self, @@ -145,18 +149,34 @@ def _load_data(self) -> list[dict[str, str | tuple[str, str], int | float]]: data = [] for k in sorted(f.keys()): - if f[k].attrs["split"] != self.split: + if f[k].attrs["split"] != self.splits[self.split]: continue for patch in sorted(f[k].keys()): if patch not in ["x", "y"]: # positive sample magnitude = float(f[k].attrs["magnitude"]) - data.append(dict(key=k, patch=patch, images=("pre", "post"), label=1, magnitude=magnitude)) + data.append( + dict( + key=k, + patch=patch, + images=("pre", "post"), + label=1, + magnitude=magnitude, + ) + ) # hard negative sample if "before" in f[k][patch].keys(): - data.append(dict(key=k, patch=patch, images=("before", "pre"), label=0, magnitude=0.0)) + data.append( + dict( + key=k, + patch=patch, + images=("before", "pre"), + label=0, + magnitude=0.0, + ) + ) f.close() return data @@ -177,7 +197,9 @@ def _load_image(self, index: int) -> Tensor: with h5py.File(self.filepath) as f: pre_array = f[key][patch][images[0]][:] + pre_array = np.nan_to_num(pre_array, nan=0) post_array = f[key][patch][images[1]][:] + post_array = np.nan_to_num(post_array, nan=0) # index specified bands and concatenate pre_array = pre_array[..., self.band_indices] @@ -211,3 +233,65 @@ def _download(self) -> None: filename=self.filename, md5=self.md5 if self.checksum else None, ) + + def plot( + self, + sample: dict[str, Tensor], + show_titles: bool = True, + suptitle: str | None = None, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional suptitle to use for figure + + Returns: + a matplotlib Figure with the rendered sample + """ + image = sample["image"].permute((1, 2, 0)).numpy() + label = cast(int, sample["label"].item()) + label_class = self.classes[label] + + # Create false color image for pre image + vv = percentile_normalization(image[..., 0]) + 1e-16 + vh = percentile_normalization(image[..., 1]) + 1e-16 + pre_fci = np.stack([vv, vh, vv / vh], axis=-1).clip(0, 1) + + # Create false color image for post image + vv = percentile_normalization(image[..., 2]) + 1e-16 + vh = percentile_normalization(image[..., 3]) + 1e-16 + post_fci = np.stack([vv, vh, vv / vh], axis=-1).clip(0, 1) + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction = cast(int, sample["prediction"].item()) + prediction_class = self.classes[prediction] + + ncols = 2 + fig, axs = plt.subplots( + nrows=1, + ncols=ncols, + figsize=(10, ncols * 5), + sharex=True, + layout="constrained", + ) + + axs[0].imshow(pre_fci) + axs[0].axis("off") + axs[0].set_title("Image Pre") + axs[1].imshow(post_fci) + axs[1].axis("off") + axs[1].set_title("Image Post") + + if show_titles: + title = f"Label: {label_class}" + if showing_predictions: + title += f"\nPrediction: {prediction_class}" + fig.supxlabel(title) + + if suptitle is not None: + fig.suptitle(suptitle) + + return fig From 5e453c076e8f6034c3a55b74c3ec61b72d804ff1 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Mon, 15 Apr 2024 16:13:47 -0500 Subject: [PATCH 03/18] update plot --- torchgeo/datasets/quakeset.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index ace4d0bd39c..cd0ef893a72 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -271,11 +271,7 @@ def plot( ncols = 2 fig, axs = plt.subplots( - nrows=1, - ncols=ncols, - figsize=(10, ncols * 5), - sharex=True, - layout="constrained", + nrows=1, ncols=ncols, figsize=(ncols * 5, 10), sharex=True ) axs[0].imshow(pre_fci) @@ -287,11 +283,14 @@ def plot( if show_titles: title = f"Label: {label_class}" + if "magnitude" in sample: + magnitude = cast(float, sample["magnitude"].item()) + title += f"\nMagnitude: {magnitude:.2f}" if showing_predictions: title += f"\nPrediction: {prediction_class}" - fig.supxlabel(title) + fig.supxlabel(title, y=0.22) if suptitle is not None: - fig.suptitle(suptitle) + fig.suptitle(suptitle, y=0.8) return fig From 4baf84ecdc4c4f79ce1b62c48804c7046544b028 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Mon, 15 Apr 2024 16:16:36 -0500 Subject: [PATCH 04/18] add plot title spacing --- torchgeo/datasets/quakeset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index cd0ef893a72..520f189a4dc 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -285,7 +285,7 @@ def plot( title = f"Label: {label_class}" if "magnitude" in sample: magnitude = cast(float, sample["magnitude"].item()) - title += f"\nMagnitude: {magnitude:.2f}" + title += f" | Magnitude: {magnitude:.2f}" if showing_predictions: title += f"\nPrediction: {prediction_class}" fig.supxlabel(title, y=0.22) @@ -293,4 +293,6 @@ def plot( if suptitle is not None: fig.suptitle(suptitle, y=0.8) + fig.tight_layout() + return fig From f49438f670bed682303ae2da47e1492ebc43fc2f Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Mon, 15 Apr 2024 16:59:13 -0500 Subject: [PATCH 05/18] fix tests --- tests/datasets/test_quakeset.py | 4 +++- torchgeo/datasets/quakeset.py | 15 +++------------ 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/tests/datasets/test_quakeset.py b/tests/datasets/test_quakeset.py index c352cb9f447..f625e3a5cc9 100644 --- a/tests/datasets/test_quakeset.py +++ b/tests/datasets/test_quakeset.py @@ -33,7 +33,9 @@ def dataset( root = str(tmp_path) split = request.param transforms = nn.Identity() - return QuakeSet(root, split, transforms, download=True, checksum=True) + return QuakeSet( + root, split, transforms=transforms, download=True, checksum=True + ) def test_getitem(self, dataset: QuakeSet) -> None: x = dataset[0] diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index 520f189a4dc..c8560bb0b73 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -55,7 +55,6 @@ class QuakeSet(NonGeoDataset): .. versionadded:: 0.6 """ - all_bands = ["VV", "VH"] filename = "earthquakes.h5" url = ("https://hf.co/datasets/DarthReca/quakeset/resolve/main/earthquakes.h5",) md5 = "76fc7c76b7ca56f4844d852e175e1560" @@ -66,7 +65,6 @@ def __init__( self, root: str = "data", split: str = "train", - bands: list[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -76,27 +74,23 @@ def __init__( Args: root: root directory where dataset can be found split: one of "train", "val", or "test" - bands: the subset of bands to load transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - AssertionError: If ``split`` or ``bands`` arguments are invalid. + AssertionError: If ``split`` argument is invalid. DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in self.splits - assert set(bands) <= set(self.all_bands) self.root = root self.split = split - self.bands = bands self.transforms = transforms self.download = download self.checksum = checksum self.filepath = os.path.join(root, self.filename) - self.band_indices = [self.all_bands.index(b) for b in bands] self._verify() @@ -200,11 +194,8 @@ def _load_image(self, index: int) -> Tensor: pre_array = np.nan_to_num(pre_array, nan=0) post_array = f[key][patch][images[1]][:] post_array = np.nan_to_num(post_array, nan=0) - - # index specified bands and concatenate - pre_array = pre_array[..., self.band_indices] - post_array = post_array[..., self.band_indices] - array = np.concatenate([pre_array, post_array], axis=-1).astype(np.float32) + array = np.concatenate([pre_array, post_array], axis=-1) + array = array.astype(np.float32) tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW From 3cf04fb9f0395c5091ca1ec932488cfe467c724f Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Mon, 15 Apr 2024 17:04:39 -0500 Subject: [PATCH 06/18] fix tests finally --- tests/datasets/test_quakeset.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/tests/datasets/test_quakeset.py b/tests/datasets/test_quakeset.py index f625e3a5cc9..b41634084cc 100644 --- a/tests/datasets/test_quakeset.py +++ b/tests/datasets/test_quakeset.py @@ -1,9 +1,11 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import builtins import os import shutil from pathlib import Path +from typing import Any import matplotlib.pyplot as plt import pytest @@ -25,7 +27,7 @@ class TestQuakeSet: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> QuakeSet: - monkeypatch.setattr(torchgeo.datasets.fire_risk, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.quakeset, "download_url", download_url) url = os.path.join("tests", "data", "quakeset", "earthquakes.h5") md5 = "127d0d6a1f82d517129535f50053a4c9" monkeypatch.setattr(QuakeSet, "md5", md5) @@ -37,6 +39,17 @@ def dataset( root, split, transforms=transforms, download=True, checksum=True ) + @pytest.fixture + def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: + import_orig = builtins.__import__ + + def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: + if name == "h5py": + raise ImportError() + return import_orig(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", mocked_import) + def test_getitem(self, dataset: QuakeSet) -> None: x = dataset[0] assert isinstance(x, dict) @@ -50,13 +63,6 @@ def test_len(self, dataset: QuakeSet) -> None: def test_already_downloaded(self, dataset: QuakeSet, tmp_path: Path) -> None: QuakeSet(root=str(tmp_path), download=True) - def test_already_downloaded_not_extracted( - self, dataset: QuakeSet, tmp_path: Path - ) -> None: - shutil.rmtree(os.path.dirname(dataset.root)) - download_url(dataset.url, root=str(tmp_path)) - QuakeSet(root=str(tmp_path), download=False) - def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match="Dataset not found"): QuakeSet(str(tmp_path)) @@ -68,5 +74,6 @@ def test_plot(self, dataset: QuakeSet) -> None: dataset.plot(x, show_titles=False) plt.close() x["prediction"] = x["label"].clone() + x["magnitude"] = torch.tensor(0.0) dataset.plot(x) plt.close() From b3c238da419348d1a6e273a671529d864c5ecfee Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Mon, 15 Apr 2024 17:11:00 -0500 Subject: [PATCH 07/18] fix mypy --- torchgeo/datasets/quakeset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index c8560bb0b73..84b621eb42b 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -131,7 +131,7 @@ def __len__(self) -> int: """ return len(self.data) - def _load_data(self) -> list[dict[str, str | tuple[str, str], int | float]]: + def _load_data(self) -> list[dict[str, str | tuple[str, str] | int | float]]: """Return the metadata for a given split. Returns: From 9380da0f188773527bcd65a9f492bc82c0b81955 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Mon, 15 Apr 2024 17:14:04 -0500 Subject: [PATCH 08/18] fix url --- torchgeo/datasets/quakeset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index 84b621eb42b..d24e79fb095 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -56,7 +56,7 @@ class QuakeSet(NonGeoDataset): """ filename = "earthquakes.h5" - url = ("https://hf.co/datasets/DarthReca/quakeset/resolve/main/earthquakes.h5",) + url = "https://hf.co/datasets/DarthReca/quakeset/resolve/main/earthquakes.h5" md5 = "76fc7c76b7ca56f4844d852e175e1560" splits = {"train": "train", "val": "validation", "test": "test"} classes = ["unaffected_area", "earthquake_affected_area"] From bcf07e14300fee149fdf36b9295709281029d144 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Mon, 15 Apr 2024 17:25:40 -0500 Subject: [PATCH 09/18] fix mypy --- torchgeo/datasets/quakeset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index d24e79fb095..53c1116348d 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable -from typing import cast +from typing import Any, cast import matplotlib.pyplot as plt import numpy as np @@ -131,7 +131,7 @@ def __len__(self) -> int: """ return len(self.data) - def _load_data(self) -> list[dict[str, str | tuple[str, str] | int | float]]: + def _load_data(self) -> list[dict[str, Any]]: """Return the metadata for a given split. Returns: From b9e0e5306080e58a07d332e26fa0725860cec962 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Mon, 15 Apr 2024 17:29:01 -0500 Subject: [PATCH 10/18] pin hf url to commit --- torchgeo/datasets/quakeset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index 53c1116348d..899696c290f 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -56,7 +56,7 @@ class QuakeSet(NonGeoDataset): """ filename = "earthquakes.h5" - url = "https://hf.co/datasets/DarthReca/quakeset/resolve/main/earthquakes.h5" + url = "https://hf.co/datasets/DarthReca/quakeset/resolve/bead1d25fb9979dbf703f9ede3e8b349f73b29f7/earthquakes.h5" md5 = "76fc7c76b7ca56f4844d852e175e1560" splits = {"train": "train", "val": "validation", "test": "test"} classes = ["unaffected_area", "earthquake_affected_area"] From 1b027d25e8c0a0d566958bd867436271605d43df Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Mon, 15 Apr 2024 18:42:27 -0500 Subject: [PATCH 11/18] fix docs --- docs/api/datamodules.rst | 2 +- docs/api/datasets.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index b21ee12a38e..ef7a91b937d 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -129,7 +129,7 @@ Potsdam .. autoclass:: Potsdam2DDataModule QuakeSet -^^^^^^^ +^^^^^^^^ .. autoclass:: QuakeSetDataModule diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 723a376efb1..e3241b6315c 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -349,7 +349,7 @@ Potsdam .. autoclass:: Potsdam2D QuakeSet -^^^^^^^ +^^^^^^^^ .. autoclass:: QuakeSet From 2454baa52370509de41c85bd0c14c70d2be82b8d Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Tue, 16 Apr 2024 12:36:51 -0500 Subject: [PATCH 12/18] update dataset docs --- docs/api/non_geo_datasets.csv | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api/non_geo_datasets.csv b/docs/api/non_geo_datasets.csv index a0d4c30ad78..2dac9021daa 100644 --- a/docs/api/non_geo_datasets.csv +++ b/docs/api/non_geo_datasets.csv @@ -29,7 +29,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `PASTIS`_,I,Sentinel-1/2,"CC-BY-4.0","2,433",19,128x128xT,10,MSI `PatternNet`_,C,Google Earth,-,"30,400",38,256x256,0.06--5,RGB `Potsdam`_,S,Aerial,-,38,6,"6,000x6,000",0.05,MSI -`QuakeSet`_,C,Sentinel-1,"OpenRAIL","3,327",2,512x512,5,SAR +`QuakeSet`_,"C, R",Sentinel-1,"OpenRAIL","3,327",2,512x512,10,SAR `ReforesTree`_,"OD, R",Aerial,"CC-BY-4.0",100,6,"4,000x4,000",0.02,RGB `RESISC45`_,C,Google Earth,"CC-BY-NC-4.0","31,500",45,256x256,0.2--30,RGB `Rwanda Field Boundary`_,S,Planetscope,"NICFI AND CC-BY-4.0",70,2,256x256,4.7,RGB + NIR From e05c2f3859c7f31b27326d26e7396e8cd4c81110 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Tue, 16 Apr 2024 12:37:17 -0500 Subject: [PATCH 13/18] add missing h5py test --- tests/datasets/test_quakeset.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/datasets/test_quakeset.py b/tests/datasets/test_quakeset.py index b41634084cc..0c361aa27a9 100644 --- a/tests/datasets/test_quakeset.py +++ b/tests/datasets/test_quakeset.py @@ -50,6 +50,15 @@ def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: monkeypatch.setattr(builtins, "__import__", mocked_import) + def test_mock_missing_module( + self, dataset: QuakeSet, tmp_path: Path, mock_missing_module: None + ) -> None: + with pytest.raises( + ImportError, + match="h5py is not installed and is required to use this dataset", + ): + QuakeSet(dataset.root, download=True, checksum=True) + def test_getitem(self, dataset: QuakeSet) -> None: x = dataset[0] assert isinstance(x, dict) From 7619d5a7588f00c47c0929493acbc79b51b273c7 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Wed, 17 Apr 2024 16:18:49 -0500 Subject: [PATCH 14/18] fixes per suggestions --- torchgeo/datasets/quakeset.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index 899696c290f..e32da304b76 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -28,9 +28,11 @@ class QuakeSet(NonGeoDataset): * Sentinel-1 SAR imagery * before/pre/post imagery of areas affected by earthquakes - * 2 multispectral bands (VV/VH) + * 2 SAR bands (VV/VH) * 3,327 pairs of pre and post images with 5 m per pixel resolution (512x512 px) * 2 classification labels (unaffected / affected by earthquake) + * pre/post image pairs represent earthquake affected areas + * before/pre image pairs represent hard negative unaffected areas * earthquake magnitudes for each sample Dataset format: @@ -245,15 +247,15 @@ def plot( label = cast(int, sample["label"].item()) label_class = self.classes[label] - # Create false color image for pre image + # Create false color image for image1 vv = percentile_normalization(image[..., 0]) + 1e-16 vh = percentile_normalization(image[..., 1]) + 1e-16 - pre_fci = np.stack([vv, vh, vv / vh], axis=-1).clip(0, 1) + fci1 = np.stack([vv, vh, vv / vh], axis=-1).clip(0, 1) - # Create false color image for post image + # Create false color image for image2 vv = percentile_normalization(image[..., 2]) + 1e-16 vh = percentile_normalization(image[..., 3]) + 1e-16 - post_fci = np.stack([vv, vh, vv / vh], axis=-1).clip(0, 1) + fci2 = np.stack([vv, vh, vv / vh], axis=-1).clip(0, 1) showing_predictions = "prediction" in sample if showing_predictions: @@ -265,10 +267,10 @@ def plot( nrows=1, ncols=ncols, figsize=(ncols * 5, 10), sharex=True ) - axs[0].imshow(pre_fci) + axs[0].imshow(fci1) axs[0].axis("off") axs[0].set_title("Image Pre") - axs[1].imshow(post_fci) + axs[1].imshow(fci2) axs[1].axis("off") axs[1].set_title("Image Post") From 17f46537e196460e71c8156c2af8260b235f87e1 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 18 Apr 2024 08:56:08 -0500 Subject: [PATCH 15/18] updates per suggestions x3 --- tests/datasets/test_quakeset.py | 2 +- tests/trainers/test_classification.py | 2 +- torchgeo/datasets/quakeset.py | 54 +++++++++++++-------------- 3 files changed, 28 insertions(+), 30 deletions(-) diff --git a/tests/datasets/test_quakeset.py b/tests/datasets/test_quakeset.py index 0c361aa27a9..15bd5a2d127 100644 --- a/tests/datasets/test_quakeset.py +++ b/tests/datasets/test_quakeset.py @@ -23,7 +23,7 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestQuakeSet: - @pytest.fixture(params=["train", "val", "test"]) + @pytest.fixture(params=["train", "validation", "test"]) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> QuakeSet: diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index fcf21069fb5..fb176a53071 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -88,7 +88,7 @@ class TestClassificationTask: def test_trainer( self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: - if name.startswith("so2sat"): + if name.startswith("so2sat") or name == "quakeset": pytest.importorskip("h5py", minversion="3") config = os.path.join("tests", "conf", name + ".yaml") diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index e32da304b76..96d62539df2 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -60,7 +60,7 @@ class QuakeSet(NonGeoDataset): filename = "earthquakes.h5" url = "https://hf.co/datasets/DarthReca/quakeset/resolve/bead1d25fb9979dbf703f9ede3e8b349f73b29f7/earthquakes.h5" md5 = "76fc7c76b7ca56f4844d852e175e1560" - splits = {"train": "train", "val": "validation", "test": "test"} + splits = ["train", "validation", "test"] classes = ["unaffected_area", "earthquake_affected_area"] def __init__( @@ -75,7 +75,7 @@ def __init__( Args: root: root directory where dataset can be found - split: one of "train", "val", or "test" + split: one of "train", "validation", or "test" transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory @@ -141,39 +141,37 @@ def _load_data(self) -> list[dict[str, Any]]: """ import h5py - f = h5py.File(self.filepath) - data = [] - for k in sorted(f.keys()): - if f[k].attrs["split"] != self.splits[self.split]: - continue - - for patch in sorted(f[k].keys()): - if patch not in ["x", "y"]: - # positive sample - magnitude = float(f[k].attrs["magnitude"]) - data.append( - dict( - key=k, - patch=patch, - images=("pre", "post"), - label=1, - magnitude=magnitude, - ) - ) - - # hard negative sample - if "before" in f[k][patch].keys(): + with h5py.File(self.filepath) as f: + for k in sorted(f.keys()): + if f[k].attrs["split"] != self.split: + continue + + for patch in sorted(f[k].keys()): + if patch not in ["x", "y"]: + # positive sample + magnitude = float(f[k].attrs["magnitude"]) data.append( dict( key=k, patch=patch, - images=("before", "pre"), - label=0, - magnitude=0.0, + images=("pre", "post"), + label=1, + magnitude=magnitude, ) ) - f.close() + + # hard negative sample + if "before" in f[k][patch].keys(): + data.append( + dict( + key=k, + patch=patch, + images=("before", "pre"), + label=0, + magnitude=0.0, + ) + ) return data def _load_image(self, index: int) -> Tensor: From a2de38327b0480fd1daaf6a6633ca83af6580a57 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 18 Apr 2024 09:11:24 -0500 Subject: [PATCH 16/18] add setup method to define validation split --- torchgeo/datamodules/quakeset.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/torchgeo/datamodules/quakeset.py b/torchgeo/datamodules/quakeset.py index 1963ba48ae2..080a7f458df 100644 --- a/torchgeo/datamodules/quakeset.py +++ b/torchgeo/datamodules/quakeset.py @@ -40,3 +40,20 @@ def __init__( K.RandomVerticalFlip(p=0.5), data_keys=["image"], ) + + def setup(self, stage: str) -> None: + """Set up datasets. + + Called at the beginning of fit, validate, test, or predict. During distributed + training, this method is called from every process across all the nodes. Setting + state here is recommended. + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + if stage in ["fit"]: + self.train_dataset = QuakeSet(split="train", **self.kwargs) + if stage in ["fit", "validate"]: + self.val_dataset = QuakeSet(split="validation", **self.kwargs) + if stage in ["test"]: + self.test_dataset = QuakeSet(split="test", **self.kwargs) From 927859935fd86c95d2a32f4c6277b7addc4719ac Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 18 Apr 2024 09:15:52 -0500 Subject: [PATCH 17/18] undo split renaming --- tests/datasets/test_quakeset.py | 2 +- torchgeo/datamodules/quakeset.py | 17 ----------------- torchgeo/datasets/quakeset.py | 6 +++--- 3 files changed, 4 insertions(+), 21 deletions(-) diff --git a/tests/datasets/test_quakeset.py b/tests/datasets/test_quakeset.py index 15bd5a2d127..0c361aa27a9 100644 --- a/tests/datasets/test_quakeset.py +++ b/tests/datasets/test_quakeset.py @@ -23,7 +23,7 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestQuakeSet: - @pytest.fixture(params=["train", "validation", "test"]) + @pytest.fixture(params=["train", "val", "test"]) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> QuakeSet: diff --git a/torchgeo/datamodules/quakeset.py b/torchgeo/datamodules/quakeset.py index 080a7f458df..1963ba48ae2 100644 --- a/torchgeo/datamodules/quakeset.py +++ b/torchgeo/datamodules/quakeset.py @@ -40,20 +40,3 @@ def __init__( K.RandomVerticalFlip(p=0.5), data_keys=["image"], ) - - def setup(self, stage: str) -> None: - """Set up datasets. - - Called at the beginning of fit, validate, test, or predict. During distributed - training, this method is called from every process across all the nodes. Setting - state here is recommended. - - Args: - stage: Either 'fit', 'validate', 'test', or 'predict'. - """ - if stage in ["fit"]: - self.train_dataset = QuakeSet(split="train", **self.kwargs) - if stage in ["fit", "validate"]: - self.val_dataset = QuakeSet(split="validation", **self.kwargs) - if stage in ["test"]: - self.test_dataset = QuakeSet(split="test", **self.kwargs) diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index 96d62539df2..edf231cacbd 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -60,7 +60,7 @@ class QuakeSet(NonGeoDataset): filename = "earthquakes.h5" url = "https://hf.co/datasets/DarthReca/quakeset/resolve/bead1d25fb9979dbf703f9ede3e8b349f73b29f7/earthquakes.h5" md5 = "76fc7c76b7ca56f4844d852e175e1560" - splits = ["train", "validation", "test"] + splits = {"train": "train", "val": "validation", "test": "test"} classes = ["unaffected_area", "earthquake_affected_area"] def __init__( @@ -75,7 +75,7 @@ def __init__( Args: root: root directory where dataset can be found - split: one of "train", "validation", or "test" + split: one of "train", "val", or "test" transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory @@ -144,7 +144,7 @@ def _load_data(self) -> list[dict[str, Any]]: data = [] with h5py.File(self.filepath) as f: for k in sorted(f.keys()): - if f[k].attrs["split"] != self.split: + if f[k].attrs["split"] != self.splits[self.split]: continue for patch in sorted(f[k].keys()): From a6c71e6211135a5864d4c8e67ddff4fd4606589c Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 19 Apr 2024 11:59:16 -0500 Subject: [PATCH 18/18] update docstring --- torchgeo/datasets/quakeset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index edf231cacbd..025b5f4987b 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -84,6 +84,7 @@ def __init__( Raises: AssertionError: If ``split`` argument is invalid. DatasetNotFoundError: If dataset is not found and *download* is False. + ImportError: if h5py is not installed """ assert split in self.splits