diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index af3b8c3a836..e6456868166 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -77,6 +77,11 @@ Deep Globe Land Cover Challenge .. autoclass:: DeepGlobeLandCoverDataModule +Digital Typhoon +^^^^^^^^^^^^^^^ + +.. autoclass:: DigitalTyphoonDataModule + ETCI2021 Flood Detection ^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8ed61a25970..729efd4d8a5 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -259,6 +259,12 @@ DFC2022 .. autoclass:: DFC2022 + +Digital Typhoon +^^^^^^^^^^^^^^^ + +.. autoclass:: DigitalTyphoon + ETCI2021 Flood Detection ^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/api/non_geo_datasets.csv b/docs/api/non_geo_datasets.csv index be4a131c8d7..03526ed5c5f 100644 --- a/docs/api/non_geo_datasets.csv +++ b/docs/api/non_geo_datasets.csv @@ -11,6 +11,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `Kenya Crop Type`_,S,Sentinel-2,"CC-BY-SA-4.0","4,688",7,"3,035x2,016",10,MSI `DeepGlobe Land Cover`_,S,DigitalGlobe +Vivid,-,803,7,"2,448x2,448",0.5,RGB `DFC2022`_,S,Aerial,"CC-BY-4.0","3,981",15,"2,000x2,000",0.5,RGB +`Digital Typhoon`_,"C, R",Himawari,"CC-BY-4.0","189,364",8,512,5000,Infrared `ETCI2021 Flood Detection`_,S,Sentinel-1,-,"66,810",2,256x256,5--20,SAR `EuroSAT`_,C,Sentinel-2,"MIT","27,000",10,64x64,10,MSI `FAIR1M`_,OD,Gaofen/Google Earth,"CC-BY-NC-SA-3.0","15,000",37,"1,024x1,024",0.3--0.8,RGB diff --git a/tests/conf/digital_typhoon_id.yaml b/tests/conf/digital_typhoon_id.yaml new file mode 100644 index 00000000000..9eb1b3eaacd --- /dev/null +++ b/tests/conf/digital_typhoon_id.yaml @@ -0,0 +1,18 @@ +model: + class_path: RegressionTask + init_args: + model: "resnet18" + num_outputs: 1 + in_channels: 3 + loss: "mse" +data: + class_path: DigitalTyphoonDataModule + init_args: + batch_size: 1 + split_by: "typhoon_id" + dict_kwargs: + root: "tests/data/digital_typhoon" + download: true + min_feature_value: + wind: 10 + sequence_length: 3 diff --git a/tests/conf/digital_typhoon_time.yaml b/tests/conf/digital_typhoon_time.yaml new file mode 100644 index 00000000000..6049a2956fd --- /dev/null +++ b/tests/conf/digital_typhoon_time.yaml @@ -0,0 +1,18 @@ +model: + class_path: RegressionTask + init_args: + model: "resnet18" + num_outputs: 1 + in_channels: 3 + loss: "mse" +data: + class_path: DigitalTyphoonDataModule + init_args: + batch_size: 1 + split_by: "time" + dict_kwargs: + root: "tests/data/digital_typhoon" + download: true + min_feature_value: + wind: 10 + sequence_length: 3 diff --git a/tests/data/digital_typhoon/WP.tar.gz b/tests/data/digital_typhoon/WP.tar.gz new file mode 100644 index 00000000000..3d707e3a5b5 Binary files /dev/null and b/tests/data/digital_typhoon/WP.tar.gz differ diff --git a/tests/data/digital_typhoon/WP.tar.gzaa b/tests/data/digital_typhoon/WP.tar.gzaa new file mode 100644 index 00000000000..3d707e3a5b5 Binary files /dev/null and b/tests/data/digital_typhoon/WP.tar.gzaa differ diff --git a/tests/data/digital_typhoon/WP.tar.gzab b/tests/data/digital_typhoon/WP.tar.gzab new file mode 100644 index 00000000000..3d707e3a5b5 Binary files /dev/null and b/tests/data/digital_typhoon/WP.tar.gzab differ diff --git a/tests/data/digital_typhoon/WP/aux_data.csv b/tests/data/digital_typhoon/WP/aux_data.csv new file mode 100644 index 00000000000..81864dd076c --- /dev/null +++ b/tests/data/digital_typhoon/WP/aux_data.csv @@ -0,0 +1,26 @@ 
+id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct +0,0.h5,1979,12,25,6,3,-55.81114066899345,76.6995939240727,973.8743108424701,44.98399850309952,66,71,75,137,25,95,1,1,1.h5,mask_40,89.87979469874404 +0,1.h5,1979,12,25,7,3,-33.621634184914114,-25.860702927919903,903.8203398162416,6.3832427352565,230,28,61,111,72,4,1,0,2.h5,mask_40,55.86768840838465 +0,2.h5,1979,12,25,8,3,72.02964248591297,-47.48138416430828,982.76724331446,0.027966770724696666,342,76,23,337,49,19,1,0,3.h5,mask_49,55.18449786430531 +0,3.h5,1979,12,25,9,2,55.920575184851316,13.989913225833078,906.0181106433341,51.01642134825744,330,90,52,258,44,65,1,1,4.h5,mask_86,15.969129252036707 +0,4.h5,1979,12,25,10,2,-43.28994147714503,-161.94483446959413,903.9366550400755,16.7093617045847,242,62,99,132,63,0,1,1,5.h5,mask_66,70.21971067939033 +1,0.h5,1988,1,22,10,2,-33.37129190053344,-115.29637290040873,948.0758912152131,51.11399505734963,118,15,67,232,63,86,1,1,1.h5,mask_15,30.245077213336646 +1,1.h5,1988,1,22,11,2,74.93228846926493,70.74999801636073,910.1992664115785,60.8348103266534,266,41,67,48,44,16,1,0,2.h5,mask_90,42.30390416164944 +1,2.h5,1988,1,22,12,2,-27.931601464223597,-141.3019006863473,961.5531323907394,18.35497901874176,19,61,24,295,50,26,1,1,3.h5,mask_67,60.35785307941444 +1,3.h5,1988,1,22,13,3,-27.166703710913154,-27.976214499674484,904.1165949703977,9.081723951290567,144,43,66,22,32,48,0,1,4.h5,mask_3,80.04417033291257 +1,4.h5,1988,1,22,14,2,47.51657289770864,-138.58539565379158,950.9654977977864,86.18819130981862,175,75,89,42,19,70,0,1,5.h5,mask_96,0.44001778199053154 +2,0.h5,1998,8,23,22,2,71.11037770397022,-170.05883586527145,902.757696015989,64.83605229043086,308,32,54,249,94,13,1,0,1.h5,mask_87,97.96789767456457 +2,1.h5,1998,8,23,23,2,-45.9880469141837,-153.85203885662787,956.1578736191437,95.77226625568278,230,17,58,214,72,21,1,0,2.h5,mask_66,48.1513473689529 +2,2.h5,1998,8,24,0,4,-88.778300647409,-78.43060469893915,958.764771469677,17.97662971655637,127,41,19,138,89,36,1,1,3.h5,mask_57,76.31799924098371 +2,3.h5,1998,8,24,1,2,-49.56689955810804,-120.3389762632577,986.4933451650326,49.259894810485605,333,90,28,51,45,99,1,0,4.h5,mask_92,65.60333971250041 +2,4.h5,1998,8,24,2,3,-52.55231579306487,80.06217230886841,997.4333837891787,48.25976623703225,63,7,13,71,55,58,1,1,5.h5,mask_73,50.634737551399034 +3,0.h5,1997,4,24,16,4,-61.81374526076493,60.62026564332362,900.1093638487514,94.66595722320622,189,70,67,249,12,58,0,1,1.h5,mask_93,99.77561346276104 +3,1.h5,1997,4,24,17,3,35.596382297289026,-117.20301531275722,925.1366339770796,34.46028512732848,55,55,74,11,0,49,1,1,2.h5,mask_11,5.726401727423658 +3,2.h5,1997,4,24,18,1,68.16880747309938,30.42194122117013,955.7265683876137,96.55057639044118,217,22,60,6,18,9,1,1,3.h5,mask_63,58.982331802755375 +3,3.h5,1997,4,24,19,3,-5.491619122910365,141.83240318855258,922.5486496962513,89.2199247408618,49,26,14,245,95,84,1,0,4.h5,mask_38,76.01607012923168 +3,4.h5,1997,4,24,20,4,4.052162855787202,21.732867986138842,990.5791999912764,98.40094253121877,158,86,11,28,11,81,0,0,5.h5,mask_12,75.84036894650622 +4,0.h5,1984,6,16,14,3,53.238650326925125,-54.63854263302531,934.2198641027621,18.697921579520305,212,16,42,91,90,56,1,1,1.h5,mask_72,78.93081269669048 +4,1.h5,1984,6,16,15,2,-56.222689844694024,-6.8726887962189664,912.6113238303491,61.286246561868666,60,81,2,198,64,76,1,0,2.h5,mask_64,24.039173626000288 
+4,2.h5,1984,6,16,16,2,-4.285643464886363,95.66534210331434,962.0580147775602,86.01251389789185,281,81,5,228,18,94,0,0,3.h5,mask_66,89.89080488339964 +4,3.h5,1984,6,16,17,2,89.15893201203946,124.94143678744513,997.342814284227,84.00590505469005,242,28,61,132,80,29,0,0,4.h5,mask_77,4.839048143310343 +4,4.h5,1984,6,16,18,1,-46.31233638346047,21.77073986978661,932.8378121656477,26.18973887839292,294,76,57,252,99,27,1,0,5.h5,mask_65,89.74882055138497 diff --git a/tests/data/digital_typhoon/WP/image/0/0.h5 b/tests/data/digital_typhoon/WP/image/0/0.h5 new file mode 100644 index 00000000000..235ea1897f3 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/0/0.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/0/1.h5 b/tests/data/digital_typhoon/WP/image/0/1.h5 new file mode 100644 index 00000000000..98ece1b9351 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/0/1.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/0/2.h5 b/tests/data/digital_typhoon/WP/image/0/2.h5 new file mode 100644 index 00000000000..40cd7317d40 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/0/2.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/0/3.h5 b/tests/data/digital_typhoon/WP/image/0/3.h5 new file mode 100644 index 00000000000..6f2be498621 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/0/3.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/0/4.h5 b/tests/data/digital_typhoon/WP/image/0/4.h5 new file mode 100644 index 00000000000..731298cdd32 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/0/4.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/1/0.h5 b/tests/data/digital_typhoon/WP/image/1/0.h5 new file mode 100644 index 00000000000..d6009711570 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/1/0.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/1/1.h5 b/tests/data/digital_typhoon/WP/image/1/1.h5 new file mode 100644 index 00000000000..3f636ec3afc Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/1/1.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/1/2.h5 b/tests/data/digital_typhoon/WP/image/1/2.h5 new file mode 100644 index 00000000000..71acdc32c82 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/1/2.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/1/3.h5 b/tests/data/digital_typhoon/WP/image/1/3.h5 new file mode 100644 index 00000000000..65b76ff2f32 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/1/3.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/1/4.h5 b/tests/data/digital_typhoon/WP/image/1/4.h5 new file mode 100644 index 00000000000..df52fb412fd Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/1/4.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/2/0.h5 b/tests/data/digital_typhoon/WP/image/2/0.h5 new file mode 100644 index 00000000000..d391fab0d71 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/2/0.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/2/1.h5 b/tests/data/digital_typhoon/WP/image/2/1.h5 new file mode 100644 index 00000000000..7b80f60255b Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/2/1.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/2/2.h5 b/tests/data/digital_typhoon/WP/image/2/2.h5 new file mode 100644 index 00000000000..c108210a0e5 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/2/2.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/2/3.h5 
b/tests/data/digital_typhoon/WP/image/2/3.h5 new file mode 100644 index 00000000000..2f1f14b9a51 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/2/3.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/2/4.h5 b/tests/data/digital_typhoon/WP/image/2/4.h5 new file mode 100644 index 00000000000..4e0fcb578fd Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/2/4.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/3/0.h5 b/tests/data/digital_typhoon/WP/image/3/0.h5 new file mode 100644 index 00000000000..d04cc4f79c0 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/3/0.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/3/1.h5 b/tests/data/digital_typhoon/WP/image/3/1.h5 new file mode 100644 index 00000000000..65ac943c680 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/3/1.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/3/2.h5 b/tests/data/digital_typhoon/WP/image/3/2.h5 new file mode 100644 index 00000000000..1ab8b197980 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/3/2.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/3/3.h5 b/tests/data/digital_typhoon/WP/image/3/3.h5 new file mode 100644 index 00000000000..9fcab04f7d1 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/3/3.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/3/4.h5 b/tests/data/digital_typhoon/WP/image/3/4.h5 new file mode 100644 index 00000000000..ccd6e248d55 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/3/4.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/4/0.h5 b/tests/data/digital_typhoon/WP/image/4/0.h5 new file mode 100644 index 00000000000..64fc6bd1b53 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/4/0.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/4/1.h5 b/tests/data/digital_typhoon/WP/image/4/1.h5 new file mode 100644 index 00000000000..1d66c74238f Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/4/1.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/4/2.h5 b/tests/data/digital_typhoon/WP/image/4/2.h5 new file mode 100644 index 00000000000..7353050bcd1 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/4/2.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/4/3.h5 b/tests/data/digital_typhoon/WP/image/4/3.h5 new file mode 100644 index 00000000000..f7185764d80 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/4/3.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/4/4.h5 b/tests/data/digital_typhoon/WP/image/4/4.h5 new file mode 100644 index 00000000000..3fde7973bb3 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/4/4.h5 differ diff --git a/tests/data/digital_typhoon/WP/metadata/0.csv b/tests/data/digital_typhoon/WP/metadata/0.csv new file mode 100644 index 00000000000..df1c40f5fba --- /dev/null +++ b/tests/data/digital_typhoon/WP/metadata/0.csv @@ -0,0 +1,6 @@ +id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct +0,0.h5,1979,12,25,6,3,-55.81114066899345,76.6995939240727,973.8743108424701,44.98399850309952,66,71,75,137,25,95,1,1,1.h5,mask_40,89.87979469874404 +0,1.h5,1979,12,25,7,3,-33.621634184914114,-25.860702927919903,903.8203398162416,6.3832427352565,230,28,61,111,72,4,1,0,2.h5,mask_40,55.86768840838465 
+0,2.h5,1979,12,25,8,3,72.02964248591297,-47.48138416430828,982.76724331446,0.027966770724696666,342,76,23,337,49,19,1,0,3.h5,mask_49,55.18449786430531 +0,3.h5,1979,12,25,9,2,55.920575184851316,13.989913225833078,906.0181106433341,51.01642134825744,330,90,52,258,44,65,1,1,4.h5,mask_86,15.969129252036707 +0,4.h5,1979,12,25,10,2,-43.28994147714503,-161.94483446959413,903.9366550400755,16.7093617045847,242,62,99,132,63,0,1,1,5.h5,mask_66,70.21971067939033 diff --git a/tests/data/digital_typhoon/WP/metadata/1.csv b/tests/data/digital_typhoon/WP/metadata/1.csv new file mode 100644 index 00000000000..3dd6e71bcc1 --- /dev/null +++ b/tests/data/digital_typhoon/WP/metadata/1.csv @@ -0,0 +1,6 @@ +id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct +1,0.h5,1988,1,22,10,2,-33.37129190053344,-115.29637290040873,948.0758912152131,51.11399505734963,118,15,67,232,63,86,1,1,1.h5,mask_15,30.245077213336646 +1,1.h5,1988,1,22,11,2,74.93228846926493,70.74999801636073,910.1992664115785,60.8348103266534,266,41,67,48,44,16,1,0,2.h5,mask_90,42.30390416164944 +1,2.h5,1988,1,22,12,2,-27.931601464223597,-141.3019006863473,961.5531323907394,18.35497901874176,19,61,24,295,50,26,1,1,3.h5,mask_67,60.35785307941444 +1,3.h5,1988,1,22,13,3,-27.166703710913154,-27.976214499674484,904.1165949703977,9.081723951290567,144,43,66,22,32,48,0,1,4.h5,mask_3,80.04417033291257 +1,4.h5,1988,1,22,14,2,47.51657289770864,-138.58539565379158,950.9654977977864,86.18819130981862,175,75,89,42,19,70,0,1,5.h5,mask_96,0.44001778199053154 diff --git a/tests/data/digital_typhoon/WP/metadata/2.csv b/tests/data/digital_typhoon/WP/metadata/2.csv new file mode 100644 index 00000000000..9f43c8edca2 --- /dev/null +++ b/tests/data/digital_typhoon/WP/metadata/2.csv @@ -0,0 +1,6 @@ +id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct +2,0.h5,1998,8,23,22,2,71.11037770397022,-170.05883586527145,902.757696015989,64.83605229043086,308,32,54,249,94,13,1,0,1.h5,mask_87,97.96789767456457 +2,1.h5,1998,8,23,23,2,-45.9880469141837,-153.85203885662787,956.1578736191437,95.77226625568278,230,17,58,214,72,21,1,0,2.h5,mask_66,48.1513473689529 +2,2.h5,1998,8,24,0,4,-88.778300647409,-78.43060469893915,958.764771469677,17.97662971655637,127,41,19,138,89,36,1,1,3.h5,mask_57,76.31799924098371 +2,3.h5,1998,8,24,1,2,-49.56689955810804,-120.3389762632577,986.4933451650326,49.259894810485605,333,90,28,51,45,99,1,0,4.h5,mask_92,65.60333971250041 +2,4.h5,1998,8,24,2,3,-52.55231579306487,80.06217230886841,997.4333837891787,48.25976623703225,63,7,13,71,55,58,1,1,5.h5,mask_73,50.634737551399034 diff --git a/tests/data/digital_typhoon/WP/metadata/3.csv b/tests/data/digital_typhoon/WP/metadata/3.csv new file mode 100644 index 00000000000..6144192b6ea --- /dev/null +++ b/tests/data/digital_typhoon/WP/metadata/3.csv @@ -0,0 +1,6 @@ +id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct +3,0.h5,1997,4,24,16,4,-61.81374526076493,60.62026564332362,900.1093638487514,94.66595722320622,189,70,67,249,12,58,0,1,1.h5,mask_93,99.77561346276104 +3,1.h5,1997,4,24,17,3,35.596382297289026,-117.20301531275722,925.1366339770796,34.46028512732848,55,55,74,11,0,49,1,1,2.h5,mask_11,5.726401727423658 
+3,2.h5,1997,4,24,18,1,68.16880747309938,30.42194122117013,955.7265683876137,96.55057639044118,217,22,60,6,18,9,1,1,3.h5,mask_63,58.982331802755375 +3,3.h5,1997,4,24,19,3,-5.491619122910365,141.83240318855258,922.5486496962513,89.2199247408618,49,26,14,245,95,84,1,0,4.h5,mask_38,76.01607012923168 +3,4.h5,1997,4,24,20,4,4.052162855787202,21.732867986138842,990.5791999912764,98.40094253121877,158,86,11,28,11,81,0,0,5.h5,mask_12,75.84036894650622 diff --git a/tests/data/digital_typhoon/WP/metadata/4.csv b/tests/data/digital_typhoon/WP/metadata/4.csv new file mode 100644 index 00000000000..c2267e37fe9 --- /dev/null +++ b/tests/data/digital_typhoon/WP/metadata/4.csv @@ -0,0 +1,6 @@ +id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct +4,0.h5,1984,6,16,14,3,53.238650326925125,-54.63854263302531,934.2198641027621,18.697921579520305,212,16,42,91,90,56,1,1,1.h5,mask_72,78.93081269669048 +4,1.h5,1984,6,16,15,2,-56.222689844694024,-6.8726887962189664,912.6113238303491,61.286246561868666,60,81,2,198,64,76,1,0,2.h5,mask_64,24.039173626000288 +4,2.h5,1984,6,16,16,2,-4.285643464886363,95.66534210331434,962.0580147775602,86.01251389789185,281,81,5,228,18,94,0,0,3.h5,mask_66,89.89080488339964 +4,3.h5,1984,6,16,17,2,89.15893201203946,124.94143678744513,997.342814284227,84.00590505469005,242,28,61,132,80,29,0,0,4.h5,mask_77,4.839048143310343 +4,4.h5,1984,6,16,18,1,-46.31233638346047,21.77073986978661,932.8378121656477,26.18973887839292,294,76,57,252,99,27,1,0,5.h5,mask_65,89.74882055138497 diff --git a/tests/data/digital_typhoon/data.py b/tests/data/digital_typhoon/data.py new file mode 100644 index 00000000000..6636a7cdbfa --- /dev/null +++ b/tests/data/digital_typhoon/data.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+
+import os
+import shutil
+
+import h5py
+import numpy as np
+import pandas as pd
+from torchvision.datasets.utils import calculate_md5
+
+# Define the root directory
+root = 'WP'
+IMAGE_SIZE = 32
+NUM_TYPHOON_IDS = 5
+NUM_IMAGES_PER_ID = 5
+CHUNK_SIZE = 2**12
+
+# If the root directory exists, remove it
+if os.path.exists(root):
+    shutil.rmtree(root)
+
+# Create the 'image' and 'metadata' directories
+os.makedirs(os.path.join(root, 'image'))
+os.makedirs(os.path.join(root, 'metadata'))
+
+# For each typhoon_id
+all_dfs = []
+for typhoon_id in range(NUM_TYPHOON_IDS):
+    # Create a directory under 'root/image/typhoon_id/'
+    os.makedirs(os.path.join(root, 'image', str(typhoon_id)), exist_ok=True)
+
+    # Create dummy .h5 files
+    image_paths_per_typhoon = []
+    for image_id in range(NUM_IMAGES_PER_ID):
+        image_file_name = f'{image_id}.h5'
+        with h5py.File(
+            os.path.join(root, 'image', str(typhoon_id), image_file_name), 'w'
+        ) as hf:
+            hf.create_dataset('Infrared', data=np.random.rand(IMAGE_SIZE, IMAGE_SIZE))
+        image_paths_per_typhoon.append(image_file_name)
+
+    start_time = pd.Timestamp(
+        year=np.random.randint(1978, 2022),
+        month=np.random.randint(1, 13),
+        day=np.random.randint(1, 29),
+        hour=np.random.randint(0, 24),
+    )
+    times = pd.date_range(start=start_time, periods=NUM_IMAGES_PER_ID, freq='H')
+    df = pd.DataFrame(
+        {
+            'id': np.repeat(typhoon_id, NUM_IMAGES_PER_ID),
+            'image_path': image_paths_per_typhoon,
+            'year': times.year,
+            'month': times.month,
+            'day': times.day,
+            'hour': times.hour,
+            'grade': np.random.randint(1, 5, NUM_IMAGES_PER_ID),
+            'lat': np.random.uniform(-90, 90, NUM_IMAGES_PER_ID),
+            'lng': np.random.uniform(-180, 180, NUM_IMAGES_PER_ID),
+            'pressure': np.random.uniform(900, 1000, NUM_IMAGES_PER_ID),
+            'wind': np.random.uniform(0, 100, NUM_IMAGES_PER_ID),
+            'dir50': np.random.randint(0, 360, NUM_IMAGES_PER_ID),
+            'long50': np.random.randint(0, 100, NUM_IMAGES_PER_ID),
+            'short50': np.random.randint(0, 100, NUM_IMAGES_PER_ID),
+            'dir30': np.random.randint(0, 360, NUM_IMAGES_PER_ID),
+            'long30': np.random.randint(0, 100, NUM_IMAGES_PER_ID),
+            'short30': np.random.randint(0, 100, NUM_IMAGES_PER_ID),
+            'landfall': np.random.randint(0, 2, NUM_IMAGES_PER_ID),
+            'intp': np.random.randint(0, 2, NUM_IMAGES_PER_ID),
+            'file_1': [f'{idx}.h5' for idx in range(1, NUM_IMAGES_PER_ID + 1)],
+            'mask_1': [
+                'mask_' + str(i) for i in np.random.randint(1, 100, NUM_IMAGES_PER_ID)
+            ],
+            'mask_1_pct': np.random.uniform(0, 100, NUM_IMAGES_PER_ID),
+        }
+    )
+
+    # Save the DataFrame to the corresponding typhoon id as metadata
+    df.to_csv(os.path.join(root, 'metadata', f'{typhoon_id}.csv'), index=False)
+
+    all_dfs.append(df)
+
+# Save the aux_data.csv
+aux_data = pd.concat(all_dfs)
+aux_data.to_csv(os.path.join(root, 'aux_data.csv'), index=False)
+
+
+# Create tarball
+shutil.make_archive(root, 'gztar', '.', root)
+
+# simulate multiple tar files
+path = f'{root}.tar.gz'
+paths = []
+with open(path, 'rb') as f:
+    # Write the entire tarball to gzaa
+    split = f'{path}aa'
+    with open(split, 'wb') as g:
+        g.write(f.read())
+    paths.append(split)
+
+# Create gzab as a copy of gzaa
+shutil.copy2(f'{path}aa', f'{path}ab')
+paths.append(f'{path}ab')
+
+
+# Calculate the md5sum of the tar files
+for path in paths:
+    print(f'{path}: {calculate_md5(path)}')
diff --git a/tests/datamodules/test_digital_typhoon.py b/tests/datamodules/test_digital_typhoon.py
new file mode 100644
index 00000000000..0ecd85f5ec7
--- /dev/null
+++ b/tests/datamodules/test_digital_typhoon.py
@@ -0,0 +1,70 @@
+# Copyright (c)
Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Test Digital Typhoon Datamodule.""" + +import os + +import pytest + +from torchgeo.datamodules import DigitalTyphoonDataModule +from torchgeo.datasets.digital_typhoon import DigitalTyphoon, _SampleSequenceDict + +pytest.importorskip('h5py', minversion='3.6') + + +class TestDigitalTyphoonDataModule: + def test_invalid_param_config(self) -> None: + with pytest.raises(AssertionError, match='Please choose from'): + DigitalTyphoonDataModule( + root=os.path.join('tests', 'data', 'digital_typhoon'), + split_by='invalid', + batch_size=2, + num_workers=0, + ) + + @pytest.mark.parametrize('split_by', ['time', 'typhoon_id']) + def test_split_dataset(self, split_by: str) -> None: + dm = DigitalTyphoonDataModule( + root=os.path.join('tests', 'data', 'digital_typhoon'), + split_by=split_by, + batch_size=2, + num_workers=0, + ) + dataset = DigitalTyphoon(root=os.path.join('tests', 'data', 'digital_typhoon')) + train_indices, val_indices = dm._split_dataset(dataset.sample_sequences) + train_sequences, val_sequences = ( + [dataset.sample_sequences[i] for i in train_indices], + [dataset.sample_sequences[i] for i in val_indices], + ) + + if split_by == 'time': + + def find_max_time_per_id( + split_sequences: list[_SampleSequenceDict], + ) -> dict[str, int]: + # Find the maximum value of each id in train_sequences + max_values: dict[str, int] = {} + for seq in split_sequences: + id: str = str(seq['id']) + value: int = max(seq['seq_id']) + if id not in max_values or value > max_values[id]: + max_values[id] = value + return max_values + + train_max_values = find_max_time_per_id(train_sequences) + val_max_values = find_max_time_per_id(val_sequences) + # Assert that each max value in train_max_values is lower + # than in val_max_values for each key id + for id, max_value in train_max_values.items(): + assert ( + id not in val_max_values or max_value < val_max_values[id] + ), f'Max value for id {id} in train is not lower than in validation.' + else: + train_ids = {seq['id'] for seq in train_sequences} + val_ids = {seq['id'] for seq in val_sequences} + + # Assert that the intersection between train_ids and val_ids is empty + assert ( + len(train_ids & val_ids) == 0 + ), 'Train and validation datasets have overlapping ids.' diff --git a/tests/datasets/test_digital_typhoon.py b/tests/datasets/test_digital_typhoon.py new file mode 100644 index 00000000000..c3df283ec35 --- /dev/null +++ b/tests/datasets/test_digital_typhoon.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch + +from torchgeo.datasets import DatasetNotFoundError, DigitalTyphoon + +pytest.importorskip('h5py', minversion='3.6') + + +class TestDigitalTyphoon: + @pytest.fixture( + params=[ + (3, {'wind': 0}, {'pressure': 1500}), + (3, {'pressure': 0}, {'wind': 100}), + ] + ) + def dataset( + self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + ) -> DigitalTyphoon: + sequence_length, min_features, max_features = request.param + + url = os.path.join('tests', 'data', 'digital_typhoon', 'WP.tar.gz{0}') + monkeypatch.setattr(DigitalTyphoon, 'url', url) + + md5sums = { + 'aa': '692ea3796c9bc9ef1e0ab6f2b8bc51ad', + 'ab': '692ea3796c9bc9ef1e0ab6f2b8bc51ad', + } + monkeypatch.setattr(DigitalTyphoon, 'md5sums', md5sums) + root = tmp_path + + transforms = nn.Identity() + return DigitalTyphoon( + root=root, + sequence_length=sequence_length, + min_feature_value=min_features, + max_feature_value=max_features, + transforms=transforms, + download=True, + checksum=True, + ) + + def test_len(self, dataset: DigitalTyphoon) -> None: + assert len(dataset) == 15 + + @pytest.mark.parametrize('index', [0, 1]) + def test_getitem(self, dataset: DigitalTyphoon, index: int) -> None: + x = dataset[index] + assert isinstance(x, dict) + assert isinstance(x['image'], torch.Tensor) + assert x['image'].min() >= 0 and x['image'].max() <= 1 + assert isinstance(x['label'], torch.Tensor) + + def test_already_downloaded(self, dataset: DigitalTyphoon) -> None: + DigitalTyphoon(root=dataset.root) + + def test_not_yet_extracted(self, tmp_path: Path) -> None: + root = os.path.join('tests', 'data', 'digital_typhoon') + filenames = ['WP.tar.gzaa', 'WP.tar.gzab'] + for filename in filenames: + shutil.copyfile(os.path.join(root, filename), tmp_path / filename) + DigitalTyphoon(root=str(tmp_path)) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + DigitalTyphoon(root=str(tmp_path)) + + def test_plot(self, dataset: DigitalTyphoon) -> None: + dataset.plot(dataset[0], suptitle='Test') + plt.close() + + sample = dataset[0] + sample['prediction'] = sample['label'] + dataset.plot(sample) + plt.close() diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 00c9da65321..f4089283242 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -59,12 +59,20 @@ def create_model(*args: Any, **kwargs: Any) -> Module: return RegressionTestModel(**kwargs) @pytest.mark.parametrize( - 'name', ['cowc_counting', 'cyclone', 'sustainbench_crop_yield', 'skippd'] + 'name', + [ + 'cowc_counting', + 'cyclone', + 'digital_typhoon_id', + 'digital_typhoon_time', + 'sustainbench_crop_yield', + 'skippd', + ], ) def test_trainer( self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: - if name == 'skippd': + if name in ['skippd', 'digital_typhoon_id', 'digital_typhoon_time']: pytest.importorskip('h5py', minversion='3.6') config = os.path.join('tests', 'conf', name + '.yaml') diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index c6bd6fc0a20..b0163963713 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -11,6 +11,7 @@ from .cowc import COWCCountingDataModule from .cyclone import TropicalCycloneDataModule from 
.deepglobelandcover import DeepGlobeLandCoverDataModule
+from .digital_typhoon import DigitalTyphoonDataModule
 from .etci2021 import ETCI2021DataModule
 from .eurosat import EuroSAT100DataModule, EuroSATDataModule, EuroSATSpatialDataModule
 from .fair1m import FAIR1MDataModule
@@ -70,6 +71,7 @@
     'ChaBuDDataModule',
     'COWCCountingDataModule',
     'DeepGlobeLandCoverDataModule',
+    'DigitalTyphoonDataModule',
     'ETCI2021DataModule',
     'EuroSATDataModule',
     'EuroSATSpatialDataModule',
diff --git a/torchgeo/datamodules/digital_typhoon.py b/torchgeo/datamodules/digital_typhoon.py
new file mode 100644
index 00000000000..11874aba8de
--- /dev/null
+++ b/torchgeo/datamodules/digital_typhoon.py
@@ -0,0 +1,112 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Digital Typhoon Data Module."""
+
+import copy
+from collections import defaultdict
+from typing import Any
+
+from torch.utils.data import Subset
+
+from ..datasets import DigitalTyphoon
+from ..datasets.digital_typhoon import _SampleSequenceDict
+from .geo import NonGeoDataModule
+from .utils import group_shuffle_split
+
+
+class DigitalTyphoonDataModule(NonGeoDataModule):
+    """Digital Typhoon Data Module."""
+
+    valid_split_types = ('time', 'typhoon_id')
+
+    def __init__(
+        self,
+        split_by: str = 'time',
+        batch_size: int = 64,
+        num_workers: int = 0,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize a new DigitalTyphoonDataModule instance.
+
+        Args:
+            split_by: Either 'time' or 'typhoon_id', which decides how to split
+                the dataset into train, val, and test sets
+            batch_size: Size of each mini-batch.
+            num_workers: Number of workers for parallel data loading.
+            **kwargs: Additional keyword arguments passed to
+                :class:`~torchgeo.datasets.DigitalTyphoon`.
+
+        """
+        super().__init__(DigitalTyphoon, batch_size, num_workers, **kwargs)
+
+        assert (
+            split_by in self.valid_split_types
+        ), f'Please choose from {self.valid_split_types}'
+        self.split_by = split_by
+
+    def _split_dataset(
+        self, sample_sequences: list[_SampleSequenceDict]
+    ) -> tuple[list[int], list[int]]:
+        """Split dataset into two parts.
+
+        Args:
+            sample_sequences: List of sample sequence dictionaries to be split
+
+        Returns:
+            a tuple of train and validation indices into ``sample_sequences``
+        """
+        if self.split_by == 'time':
+            # split dataset such that only unseen future time steps of storms
+            # are contained in validation
+            grouped_sequences = defaultdict(list)
+            for idx, seq in enumerate(sample_sequences):
+                grouped_sequences[seq['id']].append((idx, seq['seq_id']))
+
+            train_indices = []
+            val_indices = []
+
+            for id, sequences in grouped_sequences.items():
+                split_idx = int(len(sequences) * 0.8)
+                train_sequences = sequences[:split_idx]
+                val_sequences = sequences[split_idx:]
+                train_indices.extend([idx for idx, _ in train_sequences])
+                val_indices.extend([idx for idx, _ in val_sequences])
+
+        else:
+            # split dataset such that the ids of storms are mutually exclusive
+            train_indices, val_indices = group_shuffle_split(
+                [x['id'] for x in sample_sequences], train_size=0.8, random_state=0
+            )
+
+        return train_indices, val_indices
+
+    def setup(self, stage: str) -> None:
+        """Set up datasets.
+
+        Args:
+            stage: Either 'fit', 'validate', 'test', or 'predict'.
+ """ + self.dataset = DigitalTyphoon(**self.kwargs) + + all_sample_sequences = copy.deepcopy(self.dataset.sample_sequences) + + train_indices, test_indices = self._split_dataset(self.dataset.sample_sequences) + + if stage in ['fit', 'validate']: + # Randomly split train into train and validation sets + index_mapping = { + new_index: original_index + for new_index, original_index in enumerate(train_indices) + } + train_sequences = [all_sample_sequences[i] for i in train_indices] + train_indices, val_indices = self._split_dataset(train_sequences) + train_indices = [index_mapping[i] for i in train_indices] + val_indices = [index_mapping[i] for i in val_indices] + + # Create train val subset dataset + self.train_dataset = Subset(self.dataset, train_indices) + self.val_dataset = Subset(self.dataset, val_indices) + + if stage in ['test']: + self.test_dataset = Subset(self.dataset, test_indices) diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 0510dd9aa54..6d6701af353 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -34,6 +34,7 @@ from .cyclone import TropicalCyclone from .deepglobelandcover import DeepGlobeLandCover from .dfc2022 import DFC2022 +from .digital_typhoon import DigitalTyphoon from .eddmaps import EDDMapS from .enviroatlas import EnviroAtlas from .errors import DatasetNotFoundError, DependencyNotFoundError, RGBBandsMissingError @@ -209,6 +210,7 @@ 'CV4AKenyaCropType', 'DeepGlobeLandCover', 'DFC2022', + 'DigitalTyphoon', 'EnviroAtlas', 'ETCI2021', 'EuroSAT', diff --git a/torchgeo/datasets/digital_typhoon.py b/torchgeo/datasets/digital_typhoon.py new file mode 100644 index 00000000000..42bb4caa1bd --- /dev/null +++ b/torchgeo/datasets/digital_typhoon.py @@ -0,0 +1,458 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Digital Typhoon dataset.""" + +import glob +import os +import tarfile +from collections.abc import Callable, Sequence +from typing import Any, ClassVar, TypedDict + +import matplotlib.pyplot as plt +import pandas as pd +import torch +from matplotlib.figure import Figure +from torch import Tensor + +from .errors import DatasetNotFoundError +from .geo import NonGeoDataset +from .utils import Path, download_url, lazy_import, percentile_normalization + + +class _SampleSequenceDict(TypedDict): + """Sample sequence dictionary.""" + + id: str + seq_id: list[int] + + +class DigitalTyphoon(NonGeoDataset): + """Digital Typhoon Dataset for Analysis Task. + + This dataset contains typhoon-centered images, derived from hourly infrared channel + images captured by meteorological satellites. It incorporates data from multiple + generations of the Himawari weather satellite, dating back to 1978. These images + have been transformed into brightness temperatures and adjusted for varying + satellite sensor readings, yielding a consistent spatio-temporal dataset that + covers over four decades. + + See `the Digital Typhoon website + `_ + for more information about the dataset. 
+ + Dataset features: + + * infrared channel images from the Himawari weather satellite (512x512 px) + at 5km spatial resolution + * auxiliary features such as wind speed, pressure, and more that can be used + for regression or classification tasks + * 1,099 typhoons and 189,364 images + + Dataset format: + + * hdf5 files containing the infrared channel images + * .csv files containing the metadata for each image + + If you use this dataset in your research, please cite the following papers: + + * https://doi.org/10.20783/DIAS.664 + + .. versionadded:: 0.6 + """ + + valid_tasks = ('classification', 'regression') + aux_file_name = 'aux_data.csv' + + valid_features = ( + 'year', + 'month', + 'day', + 'hour', + 'grade', + 'lat', + 'lng', + 'pressure', + 'wind', + 'dir50', + 'long50', + 'short50', + 'dir30', + 'long30', + 'short30', + 'landfall', + 'intp', + ) + + url = 'https://hf.co/datasets/torchgeo/digital_typhoon/resolve/cf2f9ef89168d31cb09e42993d35b068688fe0df/WP.tar.gz{0}' + + md5sums: ClassVar[dict[str, str]] = { + 'aa': '3af98052aed17e0ddb1e94caca2582e2', + 'ab': '2c5d25455ac8aef1de33fe6456ab2c8d', + } + + min_input_clamp = 170.0 + max_input_clamp = 300.0 + + data_root = 'WP' + + def __init__( + self, + root: Path = 'data', + task: str = 'regression', + features: Sequence[str] = ['wind'], + targets: Sequence[str] = ['wind'], + sequence_length: int = 3, + min_feature_value: dict[str, float] | None = None, + max_feature_value: dict[str, float] | None = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new Digital Typhoon dataset instance. + + Args: + root: root directory where dataset can be found + task: whether to load 'regression' or 'classification' labels + features: which auxiliary features to return + targets: which auxiliary features to use as targets + sequence_length: length of the sequence to return + min_feature_value: minimum value for each feature + max_feature_value: maximum value for each feature + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + AssertionError: If any arguments are invalid. + DatasetNotFoundError: If dataset is not found and *download* is False. + DependencyNotFoundError: If h5py is not installed. + """ + lazy_import('h5py') + self.root = root + self.transforms = transforms + self.download = download + self.checksum = checksum + self.sequence_length = sequence_length + + self.min_feature_value = min_feature_value + self.max_feature_value = max_feature_value + + assert ( + task in self.valid_tasks + ), f'Please choose one of {self.valid_tasks}, you provided {task}.' 
+        self.task = task
+
+        assert set(features).issubset(set(self.valid_features))
+        self.features = features
+
+        assert set(targets).issubset(set(self.valid_features))
+        self.targets = targets
+
+        self._verify()
+
+        self.aux_df = pd.read_csv(
+            os.path.join(root, self.data_root, self.aux_file_name)
+        )
+        self.aux_df['datetime'] = pd.to_datetime(
+            self.aux_df[['year', 'month', 'day', 'hour']]
+        )
+
+        self.aux_df = self.aux_df.sort_values(['year', 'month', 'day', 'hour'])
+        self.aux_df['seq_id'] = self.aux_df.groupby(['id']).cumcount()
+
+        self.aux_df.columns = [str(col) for col in self.aux_df.columns]
+
+        # Compute the hour difference between consecutive images per typhoon id
+        self.aux_df['hour_diff_consecutive'] = (
+            self.aux_df.sort_values(['id', 'datetime'])
+            .groupby('id')['datetime']
+            .diff()
+            .dt.total_seconds()
+            / 3600
+        )
+
+        # Compute the hour difference between the first and second entry
+        self.aux_df['hour_diff_to_next'] = (
+            self.aux_df.groupby('id')['datetime']
+            .shift(-1)
+            .sub(self.aux_df['datetime'])
+            .abs()
+            .dt.total_seconds()
+            / 3600
+        )
+
+        self.aux_df['hour_diff'] = self.aux_df['hour_diff_consecutive'].combine_first(
+            self.aux_df['hour_diff_to_next']
+        )
+        self.aux_df.drop(
+            ['hour_diff_consecutive', 'hour_diff_to_next'], axis=1, inplace=True
+        )
+
+        # a 0 hour difference marks the last time step of each typhoon sequence;
+        # we only want to keep images that are at most 1 hour apart
+        self.aux_df = self.aux_df[self.aux_df['hour_diff'] <= 1]
+        # Filter out all ids that have fewer than sequence_length entries
+        self.aux_df = self.aux_df.groupby('id').filter(
+            lambda x: len(x) >= self.sequence_length
+        )
+
+        # Filter aux_df according to min_feature_value
+        if self.min_feature_value is not None:
+            for feature, min_value in self.min_feature_value.items():
+                self.aux_df = self.aux_df[self.aux_df[feature] >= min_value]
+
+        # Filter aux_df according to max_feature_value
+        if self.max_feature_value is not None:
+            for feature, max_value in self.max_feature_value.items():
+                self.aux_df = self.aux_df[self.aux_df[feature] <= max_value]
+
+        # collect target mean and std for each target
+        self.target_mean: dict[str, float] = self.aux_df[self.targets].mean().to_dict()
+        self.target_std: dict[str, float] = self.aux_df[self.targets].std().to_dict()
+
+        def _get_subsequences(df: pd.DataFrame, k: int) -> list[dict[str, list[int]]]:
+            """Generate all possible subsequences of length k for a given group.
+
+            Args:
+                df: grouped dataframe of a single typhoon
+                k: length of the subsequences to generate
+
+            Returns:
+                list of all possible subsequences of length k for a given typhoon id
+            """
+            min_seq_id = df['seq_id'].min()
+            max_seq_id = df['seq_id'].max()
+
+            # generate possible subsequences of length k for the group
+            subsequences = [
+                {'id': df['id'].iloc[0], 'seq_id': list(range(i, i + k))}
+                for i in range(min_seq_id, max_seq_id - k + 2)
+            ]
+            return [
+                subseq
+                for subseq in subsequences
+                if set(subseq['seq_id']).issubset(df['seq_id'])
+            ]
+
+        self.sample_sequences: list[_SampleSequenceDict] = [
+            item
+            for sublist in self.aux_df.groupby('id')[['seq_id', 'id']]
+            .apply(_get_subsequences, k=self.sequence_length)
+            .tolist()
+            for item in sublist
+        ]
+
+    def __getitem__(self, index: int) -> dict[str, Any]:
+        """Return an index within the dataset.
+ + Args: + index: index to return + + Returns: + data, labels, and metadata at that index + """ + sample_entry = self.sample_sequences[index] + sample_df = self.aux_df[ + (self.aux_df['id'] == sample_entry['id']) + & (self.aux_df['seq_id'].isin(sample_entry['seq_id'])) + ] + + sample = {'image': self._load_image(sample_df)} + # load features of the last image in the sequence + sample.update( + self._load_features( + os.path.join( + self.root, + self.data_root, + 'metadata', + str(sample_df.iloc[-1]['id']) + '.csv', + ), + sample_df.iloc[-1]['image_path'], + ) + ) + + # torchgeo expects a single label + sample['label'] = torch.Tensor([sample[target] for target in self.targets]) + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + length of the dataset + """ + return len(self.sample_sequences) + + def _load_image(self, sample_df: pd.DataFrame) -> Tensor: + """Load a single image. + + Args: + sample_df: df holding all information necessary to load the + consecutive images in the sequence + + Returns: + concatenation of all images in the sequence over channel dimension + """ + + def load_image_tensor(id: str, filepath: str) -> Tensor: + """Load a single image tensor from a h5 file. + + Args: + id: typhoon id + filepath: path to the h5 file + + Returns: + image tensor + """ + h5py = lazy_import('h5py') + + full_path = os.path.join(self.root, self.data_root, 'image', id, filepath) + with h5py.File(full_path, 'r') as h5f: + # tensor with added channel dimension + tensor = torch.from_numpy(h5f['Infrared'][:]).unsqueeze(0) + + # follow normalization procedure + # https://github.com/kitamoto-lab/benchmarks/blob/1bdbefd7c570cb1bdbdf9e09f9b63f7c22bbdb27/analysis/regression/FrameDatamodule.py#L94 + tensor = torch.clamp(tensor, self.min_input_clamp, self.max_input_clamp) + tensor = (tensor - self.min_input_clamp) / ( + self.max_input_clamp - self.min_input_clamp + ) + return tensor + + # tensor of shape [sequence_length, height, width] + tensor = torch.cat( + [ + load_image_tensor(str(id), filepath) + for id, filepath in zip(sample_df['id'], sample_df['image_path']) + ] + ).float() + return tensor + + def _load_features(self, filepath: str, image_path: str) -> dict[str, Any]: + """Load features for the corresponding image. 
+
+        Args:
+            filepath: path of the feature file to load
+            image_path: image path for the unique image for which to retrieve features
+
+        Returns:
+            features for image
+        """
+        feature_df = pd.read_csv(filepath)
+        feature_df = feature_df[feature_df['file_1'] == image_path]
+        feature_dict = {
+            name: torch.tensor(feature_df[name].item()).float()
+            for name in self.features
+        }
+        # normalize the targets for regression
+        if self.task == 'regression':
+            for feature, mean in self.target_mean.items():
+                feature_dict[feature] = (
+                    feature_dict[feature] - mean
+                ) / self.target_std[feature]
+        return feature_dict
+
+    def _verify(self) -> None:
+        """Verify the integrity of the dataset."""
+        # Check if the extracted files already exist
+        exists = []
+        path = os.path.join(self.root, self.data_root, 'image', '*', '*.h5')
+        if glob.glob(path):
+            exists.append(True)
+        else:
+            exists.append(False)
+
+        # check if the aux_data.csv file exists
+        exists.append(
+            os.path.exists(os.path.join(self.root, self.data_root, self.aux_file_name))
+        )
+        if all(exists):
+            return
+
+        # Check if the tar.gz files have already been downloaded
+        exists = []
+        for suffix in self.md5sums.keys():
+            path = os.path.join(self.root, f'{self.data_root}.tar.gz{suffix}')
+            exists.append(os.path.exists(path))
+
+        if all(exists):
+            self._extract()
+            return
+
+        # Check if the user requested to download the dataset
+        if not self.download:
+            raise DatasetNotFoundError(self)
+
+        # Download and extract the dataset
+        self._download()
+        self._extract()
+
+    def _download(self) -> None:
+        """Download the dataset."""
+        for suffix, md5 in self.md5sums.items():
+            download_url(
+                self.url.format(suffix), self.root, md5=md5 if self.checksum else None
+            )
+
+    def _extract(self) -> None:
+        """Extract the dataset."""
+        # Extract tarball
+        for suffix in self.md5sums.keys():
+            with tarfile.open(
+                os.path.join(self.root, f'{self.data_root}.tar.gz{suffix}')
+            ) as tar:
+                tar.extractall(path=self.root)
+
+    def plot(
+        self,
+        sample: dict[str, Any],
+        show_titles: bool = True,
+        suptitle: str | None = None,
+    ) -> Figure:
+        """Plot a sample from the dataset.
+
+        Args:
+            sample: a sample returned by :meth:`__getitem__`
+            show_titles: flag indicating whether to show titles above each panel
+            suptitle: optional suptitle to use for figure
+
+        Returns:
+            a matplotlib Figure with the rendered sample
+        """
+        image, label = sample['image'], sample['label']
+
+        image = percentile_normalization(image)
+
+        showing_predictions = 'prediction' in sample
+        if showing_predictions:
+            prediction = sample['prediction']
+
+        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
+
+        ax.imshow(image.permute(1, 2, 0))
+        ax.axis('off')
+
+        if show_titles:
+            title_dict = {
+                label_name: label[idx].item()
+                for idx, label_name in enumerate(self.targets)
+            }
+            title = f'Label: {title_dict}'
+            if showing_predictions:
+                title_dict = {
+                    label_name: prediction[idx].item()
+                    for idx, label_name in enumerate(self.targets)
+                }
+                title += f'\nPrediction: {title_dict}'
+            ax.set_title(title)
+
+        if suptitle is not None:
+            plt.suptitle(suptitle)
+
+        return fig
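
Illustrative usage sketch of the new dataset class (not part of the patch): the constructor arguments mirror those defined in torchgeo/datasets/digital_typhoon.py above, while the root path and feature selection are placeholder choices and assume the WP archive has already been downloaded and extracted under ./data.

# Minimal sketch: load and inspect one sample from the Digital Typhoon dataset.
from torchgeo.datasets import DigitalTyphoon

ds = DigitalTyphoon(
    root='data',
    task='regression',
    features=['wind', 'pressure'],
    targets=['wind'],
    sequence_length=3,
    min_feature_value={'wind': 10},
)
sample = ds[0]
# frames of a sequence are concatenated along the channel dimension,
# so 'image' has shape [sequence_length, 512, 512]
print(sample['image'].shape)
# 'label' holds the (normalized) regression targets, here just 'wind'
print(sample['label'])
ds.plot(sample, suptitle='Digital Typhoon sample')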
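The test configs above (tests/conf/digital_typhoon_id.yaml and digital_typhoon_time.yaml) drive the same components through LightningCLI; a hedged sketch of the equivalent plain Python, assuming a standard Lightning Trainer is available, is shown below. The task hyperparameters are copied from the config, while batch size, worker count, and the data root are arbitrary placeholders.

# Sketch of training a RegressionTask on the new datamodule.
from lightning.pytorch import Trainer

from torchgeo.datamodules import DigitalTyphoonDataModule
from torchgeo.trainers import RegressionTask

datamodule = DigitalTyphoonDataModule(
    split_by='typhoon_id',          # storm ids are disjoint between train and val
    batch_size=32,
    num_workers=4,
    root='data',                    # forwarded to DigitalTyphoon via **kwargs
    sequence_length=3,
    min_feature_value={'wind': 10},
)
# in_channels matches sequence_length because frames are stacked as channels
task = RegressionTask(model='resnet18', in_channels=3, num_outputs=1, loss='mse')

trainer = Trainer(fast_dev_run=True)
trainer.fit(model=task, datamodule=datamodule)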
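The input normalization applied in DigitalTyphoon._load_image can be restated on its own: brightness temperatures are clamped to [170, 300] and then min-max scaled to [0, 1], which is why the dataset test asserts that image values lie between 0 and 1. A small standalone sketch of that formula:

# Standalone restatement of the normalization used in _load_image.
import torch

MIN_CLAMP, MAX_CLAMP = 170.0, 300.0


def normalize_infrared(frame: torch.Tensor) -> torch.Tensor:
    # clamp to the valid brightness-temperature range, then rescale to [0, 1]
    frame = torch.clamp(frame, MIN_CLAMP, MAX_CLAMP)
    return (frame - MIN_CLAMP) / (MAX_CLAMP - MIN_CLAMP)


# e.g. a brightness temperature of 250 maps to (250 - 170) / (300 - 170) ~ 0.615
print(normalize_infrared(torch.tensor(250.0)))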