From 8dc01f8350f1cf23a0b0436a265e177b73599581 Mon Sep 17 00:00:00 2001 From: Nils Lehmann <35272119+nilsleh@users.noreply.github.com> Date: Sat, 9 Jul 2022 23:10:29 +0200 Subject: [PATCH] Add ReforesTree dataset (#582) * add ReforesTree dataset * fix failing test * suggested changes * Update download URL * Change zipfile name * Minor fixes * Remove f-string * Fix dtype, remove unnecessary conversion Co-authored-by: Caleb Robinson Co-authored-by: Adam J. Stewart --- docs/api/datasets.rst | 5 + docs/api/non_geo_datasets.csv | 1 + tests/data/reforestree/data.py | 75 +++++ tests/data/reforestree/reforesTree.zip | Bin 0 -> 7296 bytes .../reforesTree/mapping/final_dataset.csv | 5 + .../tiles/Site1/Site1_RGB_0_0_0_4000_4000.png | Bin 0 -> 3172 bytes .../tiles/Site2/Site2_RGB_0_0_0_4000_4000.png | Bin 0 -> 3172 bytes tests/datasets/test_reforestree.py | 104 +++++++ torchgeo/datasets/__init__.py | 2 + torchgeo/datasets/reforestree.py | 291 ++++++++++++++++++ 10 files changed, 483 insertions(+) create mode 100644 tests/data/reforestree/data.py create mode 100644 tests/data/reforestree/reforesTree.zip create mode 100644 tests/data/reforestree/reforesTree/mapping/final_dataset.csv create mode 100644 tests/data/reforestree/reforesTree/tiles/Site1/Site1_RGB_0_0_0_4000_4000.png create mode 100644 tests/data/reforestree/reforesTree/tiles/Site2/Site2_RGB_0_0_0_4000_4000.png create mode 100644 tests/datasets/test_reforestree.py create mode 100644 torchgeo/datasets/reforestree.py diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index b6f51307e5c..50741377e8d 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -244,6 +244,11 @@ Potsdam .. autoclass:: Potsdam2D +ReforesTree +^^^^^^^^^^^ + +.. autoclass:: ReforesTree + RESISC45 ^^^^^^^^ diff --git a/docs/api/non_geo_datasets.csv b/docs/api/non_geo_datasets.csv index 3ebfc26668e..2e37e18e21b 100644 --- a/docs/api/non_geo_datasets.csv +++ b/docs/api/non_geo_datasets.csv @@ -20,6 +20,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands `OSCD`_,CD,Sentinel-2,24,2,"40--1,180",60,MSI `PatternNet`_,C,Google Earth,"30,400",38,256x256,0.06--5,RGB `Potsdam`_,S,Aerial,38,6,"6,000x6,000",0.05,MSI +`ReforesTree`_,"OD, R",Aerial,100,"4,000x4,000",0.02,RGB `RESISC45`_,C,Google Earth,"31,500",45,256x256,0.2--30,RGB `Seasonal Contrast`_,T,Sentinel-2,100K--1M,-,264x264,10,MSI `SEN12MS`_,S,"Sentinel-1/2, MODIS","180,662",33,256x256,10,"SAR, MSI" diff --git a/tests/data/reforestree/data.py b/tests/data/reforestree/data.py new file mode 100644 index 00000000000..27573cb6191 --- /dev/null +++ b/tests/data/reforestree/data.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import csv +import hashlib +import os +import shutil +from typing import List + +import numpy as np +from PIL import Image + +SIZE = 32 + +np.random.seed(0) + +PATHS = { + "images": [ + "tiles/Site1/Site1_RGB_0_0_0_4000_4000.png", + "tiles/Site2/Site2_RGB_0_0_0_4000_4000.png", + ], + "annotation": "mapping/final_dataset.csv", +} + + +def create_annotation(path: str, img_paths: List[str]) -> None: + cols = ["img_path", "xmin", "ymin", "xmax", "ymax", "group", "AGB"] + data = [] + for img_path in img_paths: + data.append( + [os.path.basename(img_path), 0, 0, SIZE / 2, SIZE / 2, "banana", 6.75] + ) + data.append( + [os.path.basename(img_path), SIZE / 2, SIZE / 2, SIZE, SIZE, "cacao", 6.75] + ) + + with open(path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(cols) + writer.writerows(data) + + +def create_img(path: str) -> None: + Z = np.random.rand(SIZE, SIZE, 3) * 255 + img = Image.fromarray(Z.astype("uint8")).convert("RGB") + img.save(path) + + +if __name__ == "__main__": + data_root = "reforesTree" + + # remove old data + if os.path.isdir(data_root): + shutil.rmtree(data_root) + + # create imagery + for path in PATHS["images"]: + os.makedirs(os.path.join(data_root, os.path.dirname(path)), exist_ok=True) + create_img(os.path.join(data_root, path)) + + # create annotations + os.makedirs( + os.path.join(data_root, os.path.dirname(PATHS["annotation"])), exist_ok=True + ) + create_annotation(os.path.join(data_root, PATHS["annotation"]), PATHS["images"]) + + # compress data + shutil.make_archive(data_root, "zip", data_root) + + # Compute checksums + with open(data_root + ".zip", "rb") as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(f"{data_root}: {md5}") diff --git a/tests/data/reforestree/reforesTree.zip b/tests/data/reforestree/reforesTree.zip new file mode 100644 index 0000000000000000000000000000000000000000..d1081a06a563214b15cb3fcece56da20882a1b07 GIT binary patch literal 7296 zcmb_>WmH_tx@~vUxHJTJx8P22cZURbcMT8-3GPlH5Zr?VcXw-?5TtPn4#9)IWS_V9 zx#!%o@3=p1tx+<@nse4yYyFxv)fC~tIDp^HJaSL#kH!BVkbcdMrY(Pr*Wn^0^8%|zO?pEJc-&>kYnJv8teQFUV;NKUI_E+(mwjP!oe;wh! ze~f=ca{h(<$6>QB-CS*dK`efC#(&fNL)lnER>GM5kB6I`o!$8N!0O^;4YWW@KaWsT zkwr%#Lit5RmzR^$_%&1f?nt0t<2}JqJ^+BmBrhed<(&bfKwCWrMm3~8HK1;Lsh{m6 zuu(^c682Qh)552YYqA;N&R+veA<03mVrstXsQ?(W#;IYg`2PVVMI260K$luN7Y7`l)XLl{! z)F$`)PSRvc9v&5&tYJv2_b*h^c*x*lE{QY(pIoGPN#VAI=$SkvDmmQZq$tpenL?=O z%-0whH4**W5w)m%s)P@su+@;2K9dbj_yX8Y+x%XLVK}K)o3J2YeQ^)py5}pAgzr{2 zO_|AAk%qFJRr3C55#tR;;C3+flu2G94WewYh-amFt9R1dv4b0(e5oKz2tFzul1;|+ z{ki(;Qt<2)nzN!Xp9!Z!d;13PF5i%9iaqcJ0qo}Td%qTrwIC&WGpLUi2d_-TJ{Hz2 zMTwy#;Zad7XTmy+Q9UpqVVo>&#Q7Q(iXE(ybWwop7#RLVhQi{UZ70%W&7;Y3r?N(~ zD^HjKdf}0>=y8?^nEF1!kYfpP5T8f7(9;96SsfDoSa#PG%s_&3#1i8dql6>Pyt%hw z%=5d)9F!OSEK(BW5F=k_t3PvqDyTDYcfF%{FG_qSSc%cHCT(KeA;jdSJS*gXn!9|` zS#$N^4iY!lmI&a{@UE{F$&W-2$HVZUi*q~V3~t_>&t~J;AV@C_AP~}*YTSjQ77F4i zNX~Zz9bi2M)bX!$?er^;q(72b%}5H&Eq8lf@m(T`(wJU{CzMs9+;I}#Dfl>EE`ab- z75xw~AZ)F5ecnv|>^XcC81K*tVFGhD(_lR$)~&&+Z@Ilt2QM@eBYABd;;jweHvPtx zMGDwgA~;gd2Zqgk`IME(?Kt}n<9=|x!pyKcAL7H4B)uf(Zk0$;cPfwvn@MNFTDexk zvxjoD+Uqv?5An#%C$tve>bZgb&e4VXwuu(PCM-U-(H?HxiN>2j6lqRKOKykeT#LOJ z{qBj=N9x|PX^qhCdo$|hSFZ0~TTkws(NUp!Kn5P4t9(2I={3Pp-{nngqA_=pAWc)V z(V@GbZPlT!+Yox-gW;L1rXy+9Iu=Kuv>EOyoPMWH40nwxaoQ)uS zwuq);Zeq_4JbMjupImuoxt7C$%~N^gha82?}S=l3%FB3hW-Ut;fW@PMpnk zf_t~(&TRz{kQzN#n56VFIUx{m zZ7>vLah^(MH(0eN(fvB{{3lF*BhL}J&wr?6L}@h2yBwZK;g(tpT}%DPoGHT0XSxT% zWku(a@{=rr^}c#^%jZ)y>sS2`8OU(jQ%jeasJqIsw^k)qy`e(V;$gLCfT}?zW7|qa zWxeqvY;2?6%DP_psK~YpnUT*){in#ZUARH~(lv-HuKp1P+Jx%QW7IFMvNG>BdQ}a% z-rnz$A#8_pl@*XnKzJxQD$oe2)Re?KAHz>AiER}9ieDeE*8<;CTJq7?b2|iXQm*3R zjI?A*y&QDCrvJ9WAK>&9GJW%Dy#gs~T3GEMnRnad@B@|I%m+DOh-SGG6&&(Q=p!0| zn@YAC^gaJnY$AvR{=3vt{0D+9J=>Ac{^*#Co1i&B+jpqy`xp__ngJV zh0_&wrWPq{hv=*%96A9xEgSbGa|0@?AQeI%I%@0i^m@{g6?c+H+{oalXbuLl;k`tTqtfcO$)lKh zADoe&1eg)EOgJc%u{-v6iXbYVZ_6V`O-d@ZO1a}}u9bN0tFXt8v;9s|OE(Sw41*@` zzMIGKg=g}FH4U%`s*c`wFD62{hh+d==Ga?^U#>4icvu#o4xGkWYMkP(HRQa(CLVPbLk2d*8skhcxzRHWoazx$GrQ@ zX!=1i299Hb`$k#_KRkIK99mhfOz4yZ)!WD(2o`KeYZk$ka!4xyHYZNfnAt-EvZhzq zA}JDWU;Qp{=JMWw;h~zUgCTXQ>~YT@p}hXiHOH%{b71H^Ifop{WKvSCMot}m4?Or<7CTR8s|5DE z*BH;V7gWHnb$2=?(lka5N&Ar2QbhYMjb&f;;&z1zafnl;GQ-zLu@HeW=R4jzjP;uL zn(*y12GlIf>42|Nw41wC&1hyI%%}j7&_t@17vbz^Zv{cM7E2Lw^@+`m%fk9iNENs0 zi>I9;4@T@6)#{VZ?Kl8I%77uKF|7N`iTeh$KJ2^=GKvqib;<-hq?-1}avGP+o2t!` zPGUKE7?RdKz&thl^Uc8Z+1E?ziN1R{_B4YacA%W@)$Q~U&9{cvc$nUDz1wGMDZzD| z(Pkhby!~b3DQCrviSdQc?xOU;(4$UuSzFdiOQ*I5M!<>#r)o*5Cq8b-u8Xg`CrYEl z*RK(;@I4le54$wJXABrMjBw&W+jq&3O;CFOlX1%awk`wK` zs--B`{240p2AOjya-{j$j?Jt|%lESeO0Hzz&#R2;NR1(B&jR$+hu_Li1kKZ7%gG-w zTr8}hp5x}KaV>FG9+v}ao7gWY(Flhc140d+((G-$DCN~gqPz1acQd1;y>WZl$!^fo zEhJyErp+1fS3wk}op+f6P8XK3-kYpZF*hPHn+atV&I}x~t0Of;RPE6i$;k^zy4(i0 z;+joU z0$70{v99+1hL<0k2Fy17JKoz+o=p8OuYhj_IT)ba*7Gtsbw9=&al=xUx~S={&w|Yf z(SaR$PG3nioUc1h;a zzSg{3V~O6;lE4KMz@K?hjD0C#>eGi*?M9S6ht&8qXvOu<<%aWIdigep|nzb*9ETYLL-wA#k(X4U2?9&4mbCMrDN+~_} zSgM&5XxW#ovQoI#AvSmGJexdv)2M=FSS^acuBl#w8H(DO~DJV7uq^=b@ zQCx%wkvD}oG#SXHSt_C70&8haeHtcc?errz$h*|s6B4s zVoO7u5*92sEid^A>a)=BVMQD8%B9Y&7$0Fh{mG1e)|>9|5LvLa2MM&MK*PSK{;ndwXre29Qe(Rl>8i(8LqO#=7owrR=r6` zzh;!K6JF6_Mf*4{z1Iy%jOTjkfkww@(lNu>S+;pijss=nSUXYw2KyABsFW$fVq2Sa zRY1dt&z?HF2|8CfIF5y)tCw=6JwdtHFrWlt1NNvf*W*GORc+t2(wX;OpuSo>eSBqu z=YunsDxHy0`~D4o+%)+B@afJOtB4dH6LoFWJnK15;lpOko?)J4z}JLEJ?kRo1mf~X zDgq$eIqLkwKAC@b1G`)f{FS)AQ4T~75yqx|v!-Nayp83H05CA=%piB+n=)@ zm=Z+dnC=#`FA@S0-fq=z#B#4z8cR4ObP!V^aOx43lucNLHO@G|ANkcXZ6<*#WrDAV zO|pVE7Tn1*TPDH^J@gcxJ8aQ?8!qxrf}-w?*L6_dx~eX(_f7`iv-L<|E^cHx_`lYkgQv zWA?}b_ZSG*ApTmqBwn{=geO9%FL<7&P$DlGl5y{pYuD3ikzWk?M&g$ zJTWriT-&eG$#bO7Ek5zV#zb~eIz47RF<(5qWEexmR788I3Lg8+%)k87%^bZpqFlzB z!@lUWB*c#B4gH6fABrH}z_hA)uhJvMxv4acEq~$x)Mv*xNxcH+fIke@a7rYsI zj|ehPFQ(lkEV>JaEfA_KdU?r_nFGjs2q5Tkul7c-o$1`#h93_;I8lxOv>#gy557Lu z0+$TP`ZsYwc_O_2yrY#gs&_MyAMFmIav6#R)6F87xyN>_x*2O?)4gHTV$7Z~#0a`C z6T8oH;drle!n_(Vjzr|1tyYa+T?t1b%gaNzQI@-R)LPF?PADn}-#ufeb!99#$rKp4 z1}gL7XeiJ-2M17z%3h@lj48tHOM+)|v+0m_$!+?Z260EJ#WliCM9Y`5%UVsQsiH`> zwkB%m?*iU`9ljnMec5aaH=OjWa+Fjn^#$qH>Dgz>kmMip-3=_?qa2bilWu?%P0A#c zpGbzJT}5e^pK7>ktjfeyZ%bL)GKRdplT*>*xEqCAS6(n6lY=)F*WC{)a!PK(o{!gD zIgl*+tbL^JYCkgy8#&;`+1jJuN)QnMap=?OdU9J~y$m6d_UGmRVa}QQe0L1@C$4C? z75{*DWqE?sgHxpHXIJhm)@oW9Dol>DL}L_lN3`$Vv6jrrSit%ry&3iBNsu&(@Fc@reW-_! zj;26k>1?^wDf@&K8THU$s1|<>l9_?L1Z}0l=5FxqbY@f-Xo7;r`fzol63?>zV<%lqGUQf*K#Rf)jN&7uwHo4bu6)VM?8cix za%&>?#qDuC3|&;#Och(E9O|&S~w;%m9(c`#VozKQ2W+ zZvKscph!a8n}-aSNV#UXxI_EUFKvRQq-2TFKy76NF>YNO29DAL!s&EmzkyZ1w;OD5 zpFZrMgS)Qp*U79oK?p>71gj2IPUiWfJwgX_9&xTE{Z*M|(4QW|o;LqRSF&c9LECC5Ko9_g|HEi|yb}uQ*q%!;0pHAeZ z+N?L9xIn>y@-?wbZKNGYuzTytcva$4N`WG;pw>eRZYx-#@t3#_HJrvi@0lcTdWS zyI2J9_)V;r%CFc9QeEPOaV>MSK0hh3y$msBY)s`B<$s4J-`uyvY%u$dQc`XXrwG4tB@fr;|ct4=am4`P!aQ;(U8p9ehE*5AbwIMM&uKp?-e4U3Zn7#dFC- zvjg_hJ;lw;TkA7&y31aUmpgE=DBNI=GZ|Gcd;(~<)IE{_Iufe7{Pg-vvNjn&;?VdX zQIfsIQ>E@Q7%HfaFKYR5W+9zN#If$sEE$0B<_~PJ|h;+WX1~l-MitxZW%>~p`ja9YNg6`?PK8ke_*ldSXuq;MrB)p7#iu844wSqAB4%Qs;<0}(;wLWHXI z;Ue#Gp)MhDLb@cUOxmEJQp{RnX+htcTA>}3tn7C28^XoS5$P3`rulx3@ z3PDifqVEgVC@bwU5OZIwznEl%yV^AJhlG-XQPSupr0`?f$Vms&`H~X(*bPMy2zh8i z*KJ@?W{}IRvBJL;O`Z>j8@3hZ>v?g)M}V$xC%Qk|?z{10Lt+V_OWreexTF(dr+F?GS~%{?C?Z5VGF5Rq9|YU+SLi7ft!L>;f4 z@j{49*PQ}<%Mo-)7DH+q!5l0-JOwEC1_U+lNKphR-J!5|nZ-?%rhUOA_+aJV@CeMQ9%>Y06b`rXx_`<9(euXvPV4wZh+gIFaPb!&c~3%!=U0i>}(=hFqt7 zG_M>AJQ09zj8Fm{_R>v}f5*o@n1(9udEJ!0v0H!?)83k^;Jv(Tq&oai96xW2hvDt^vRI#4|!$@eqWP>MQBm7$hyH-d9( z_{pu#(BxrKFrpgyU;PRYgah~Y!PIXo0DuX25dS^@Gp71`*6*O|AMH1b?$`3q*y>Nt ze-5txWYq;I{969G(|?Px)D-`e`Oo3eKeyieYwJHlr~fSz{a2RukD~uDX}?0KKbikI zhWcmeeE-7a_`~FJ{GG}BciDe;i~p4Uj{dI&{g2*JO%eW&e}D);1kC)Z=OfW?q5lW# C73+Hd literal 0 HcmV?d00001 diff --git a/tests/data/reforestree/reforesTree/mapping/final_dataset.csv b/tests/data/reforestree/reforesTree/mapping/final_dataset.csv new file mode 100644 index 00000000000..9c71d73563a --- /dev/null +++ b/tests/data/reforestree/reforesTree/mapping/final_dataset.csv @@ -0,0 +1,5 @@ +img_path,xmin,ymin,xmax,ymax,group,AGB +Site1_RGB_0_0_0_4000_4000.png,0,0,16.0,16.0,banana,6.75 +Site1_RGB_0_0_0_4000_4000.png,16.0,16.0,32,32,cacao,6.75 +Site2_RGB_0_0_0_4000_4000.png,0,0,16.0,16.0,banana,6.75 +Site2_RGB_0_0_0_4000_4000.png,16.0,16.0,32,32,cacao,6.75 diff --git a/tests/data/reforestree/reforesTree/tiles/Site1/Site1_RGB_0_0_0_4000_4000.png b/tests/data/reforestree/reforesTree/tiles/Site1/Site1_RGB_0_0_0_4000_4000.png new file mode 100644 index 0000000000000000000000000000000000000000..95e37237fe8fb22ab0660b1fd3499dad3dc5f2ca GIT binary patch literal 3172 zcmV-q44dmX07OOsMj@!9JAr4GJbgr>TU@plHvBt(iUzJ5(GIa~f6NFBa^rQ)Gf{3j~ zAY}6|=8AOSbhu^aud-9XqCF_f^mLZmbn=n_(dW-7pk)tiLD&fNT3Q4zX6za9-E&hr zoCpJL5kNRVAp;$q`21)oqJ8`j$4Nb?KBc(t=WqlyiP77ltQk8k4+92E18Ur16{V3z=NtWKw<&SHW&V2*JJuKV%dK>KTlUkE+ z!A^=lb0PwCFJQls;H*@(68)?Uv-iA7l#mH5?1E*8^ zYpM;vtB$jeT=_&j}KE{1k?Tg3IL=+w)-Z-x8py_*vzsX3gK5(KysEv z3&E-%>BBAuuh&0h&;pZq`DxZ+kF>L1lP8yWu0>gXt; zj3p%p{}iOqrfEuy57WQ_8I?`8?bA?E7K((tdRYP&E1DT$8iq)5HB!m!r2@Xb@pNJS#0*#}^*1fK>y^4}G0^ds0n48P36)bRB$`z{3E>)L}!YDC}00eJ1B-h+rR zo=MrW5rV=8Qn{7%B=667~#aV4P`3Pvf=i{{hSE{SJB+{5DVx{j{0V) zX2XU&ML~zN@&K^MD_?1^OHEqNlM@qO!>_c%NsEfS^hV04lg01}CBPMfI7PAt-f@75 zq*obHDveR}_MV;l;=@l}GkpB$9tP!zGpnQ@LJu_|F|G_5B2Y|0zy65uWgBQqeW!Kq z;j;mJA!RowwKQyk<00P`6w11tL{-Oe_b1lgIDu~eh0^$%;jRdt(mhb?l{V#I?35yE z)09X7g;T6dA_ECk%l-@)a!#L6%YQi0j*E#M-y{<&~8P{vo{4 zUosiKbG7Lu>S#^y{4%E-rP8iyVPuxz>&bjChJ06;Z#`(#n!W^!mi zj4>!4$?1|Y?W(c7(Cv-2dKAj}7!ZlGD-;SLj^=CoO9CQ#*4@hOxJ*uIOrg%;ajzD= z-`oDa^6I}GWpYx0oLsng#Q6Ts*%C<^vbNiaBEG|ZInRbP$wmOcV-w^Dt@rdkH7(i8 z{XRU$jQ_S`HLZ>ZjB2tIp3klqeZ>)j0nzzCY#7McY)vSRRf64w5hQT3Hq9a$7Vb@7D-%B2Ur`C`=h1s|QDCSnr}Q zjqqw>F5YZ+A*zlGm!kZgj3>t)C^7CB{I?|!I0YZX0}QL7MuyQsB7BO%4-wnQRMiG? zA;?sb+t)!Hx1GWy0%s>mB&Wq?9m30bHy_T*Y^=5WaiDjmzsAV(m7pI8V9-NQ zvPe+{=>*4hM-!r0#!(SF*!vdr@RQvGJKDf;z51Evix~3(5nJwy1^!nZ_2gou` zubg{&OQi-Opyw8T4&kzYQw6<7Tq7+kn*h{ACFAI@xeQ_g5Q~96hLR$>IfyU}d^-ZO zWZa2JwUW8F-91~l2(J23+VH;b_%yb)YOSq7j)U~?7&jiRQ9c;gQ5`|1j9fiuCwiD>^}$&kr0QX z%{Uc07rws`o;hhNNI$>!vmmTf`G!7u2{PP@%DEEeXw_i4=km58GnIR&-ziZDw}qJ! zfm%_?d`ZuP*d`N5N$V)|+1?BD6`@cw+%vE4tO2vQF!dn}8OXPRhFt%dYiV~ONl?m+ z!K2XVor^_y6~izd_zjz6Lsc%B*IYQT4@%N+=qrKn+1(L;VBjJww+Jj^Jf5Y~#_TXr z2)2o^=_FoANjyVv`h>d`V$hlv!X5%;D&NX)>jo7vU!!)`?)))Ingc}X)sHO5MERJG zJ`R`=+@Iy{Ci)Qr{M4k|q&pi~SL)(-eJ@hmX`AO3e1$3u-qC=u?4%@b8bfgFZta=^<7IQ<@-~!RT^ZBAJBd~m&#MJ zZS}-j1gA0|?fyheDs?RmFO5nZuH+sQFh(mqsI%n1^c zWdNZ==fV-k_wc(Pl1TTqo3U?S?oE5!Eh>H(HPk5MOa ze*6kd*{IT}!F@E&kR#3XSV|PB^dfs12&vsx-yJ;Vftv{+4(1}*S9l%;2Pm`dx-&1C2hlw{tYNIF`2xgbQhF{7#&@A3p{~x#F7P+u6KBEz0000< KMNUMnLSTYar2(q| literal 0 HcmV?d00001 diff --git a/tests/data/reforestree/reforesTree/tiles/Site2/Site2_RGB_0_0_0_4000_4000.png b/tests/data/reforestree/reforesTree/tiles/Site2/Site2_RGB_0_0_0_4000_4000.png new file mode 100644 index 0000000000000000000000000000000000000000..69a0632fc1a4f42a1529629f5660e3d84e7cdade GIT binary patch literal 3172 zcmV-q44dwVYCsrk%C$*^S;Hp@T$uz8W2_bElRJ>FO@442YH z>A??@kTX?u3??aHzSAkctK;_{6w4_w;O|k^+nSJ)Or1V0X|ta8q$DX9FqqZj0`pGm z?vBe1QK~bU|I0Hk4$Bxo#%l|W_s%?0u)Nk(n=I)$3s&0j{#IxfdKA}~MVy?oe|b2M z(jUeF|NCzdrX2+k3*gLSo+=cj{o{@4U7}@y)RDJZXQnKX8?62!7y&Qy3)ubY9)O6p zFi4;U_CZ@-pbtm~+b>c0;7n#;-Z09;(Z*`$TG!zLWy>epw+(K`q+(!o4yraS=| z`sB6Zjx^t|UqWt?z8fM2GFlm_tIuYKx6^C|?R~Q=fMAy3dED zbXrR)Y~&`^$)f7 z&lTjcyjw`+%LaAXw@&2qU~>&)`KQj_!iM4mBNJOe$JhxoVs~@yT=4VAM#OxXkCC1; zZnf6W&g^9QNZ+6fadgV*V`yDh2&my^3(Q zp|S8i^U3k;jIDAd(;@&mw?Td<0I@w%#ps+Gil9)_$C4zm_NU&Ti94lC-RlDL?jFGLrPtZ09uW!U+#@07fd2d4*6)%&qI#AZ zvaCK2SjYe`tVSa3m*Y;;@MPvyQvLEe)mQo0=8{V?hO50Rh;KEGMu;=zR+}H7Me}5u zH`~oV=nB7d)pig${Z%N<3lK{T{jh}1s4O_$RdQnuvx%%mXEAH0@TrAr8hIy_y75Z_ zHiDV3*mtV!OV`nvG30<7#tRberiZ!792FPLmY^l9s`NITTK)z`|2figLS(_E$z%qw zWWzQe2`qsh!Wa(?NLK00_irY0yvhFSlx`u)09XFIUF+2TvjN;(9>wDo0-`=PfHusp zBv1R(ikE8a%SfC{q|&)Q5TWjBE?J!5K+?m9BS0*6jT;78Rg%H-p#wJepoe$14(&cj z65p`S_VzuB2}w!IO%z6p`RuQ=pSgsxW^UDBSBJ!??0^EpR2oI(jto-CwoXB{1m6#0!G`Bsw-Z>}<9|Yps;dAS*psDzWD$cU@Y#iEp;Flx7z4Bg%%Ih{1 zIK9P@k{hV-$NVx-# z>}!UtygRBL9+Hd!S4{>$G+AgUF{;x7Z5l;#Zt!1w&&9%nKF1{$Im z-)tgoW1}6yJnPqVk8!ERu$`;R`E<#4_*{!9En*hXaPs4>s-gdw+whp^E&lqPJ{lk^ zTAeiGK?XC;Bx5=eIznP-=NLV*FSVm;RUxZHoyFn*?@2_wXXF1f0)&D|vX0!m9li+L z{AbXRPaE(dI;MAnyX<5%X4{fqt&ieR6t~0))nulI8In)G|EcjfRxpb!6dKV+LoAFJ zjTn30krR0U)Hf7bg@WdzH(oE-i3%VrgR^#g%0p930( z?m#H+vUZ_I1!e91p2@`-fKG|^%jRS z-7&i=@2P3^g<&bTm^eQ;ehf*wp3uDok#Ai6r$C1ozE~?M1ID>~PGjL72RDQP8nJTm z=tf}blU~g0t2)qYvV89mCownyJSMRF7J@FgE9M>i@Ug?`g)Os!N^*zu_YH;xLb3%u%KP09U%S{u}_lk+8shTEt+U<0t|_%aHMlL&K-hs&hsTP77^0vp5vh z55Mgjj&sYNMgV)`<{#S#7xrQjcMAlMHJ()=`#%pyKulWA4_X@fBzk(sMZ5;=8G0;T zqa;bK*8)JxXD8SIy)T8e0O%ZxAs=`-17iVzrmA`$HH2M4mLFuY@;5wW)bz;oPulA_XD2j^VdOGbD&de~F9!eOSEB}H7#Y=0E<|8$2 zU#z=`9NxtML={5NP@sYX_8JtGEypKEtmE5Q#)*R;JQ+{5h(3Q6%hSW9P4&@aBD#6| zA?1d-W-7r9S403pk1H&?-t9qNLH9}iS1X!HJ90JJC|ykv0X^-uTgOU`{I;~^*o-?% z4v#ZdOk=k>aIN&B?sqQ8%_A*|Iq(VC3Y4GRJbct0)C z?bP($yF>JBrmic&aOhf*miUVSlH#i7esyQm+5$1`%bu7IixNjDTU z+vJCPcJ%Kz7!6x$KkC)Jd*bopLfimZAHB}|dO#3qlYS{)>UDy6Vh=ihm3dsuvGE7J z4@&99sAUX7dTVqCp>hQ8GAqDXn*nYHH8J;1)@$3T0UQSxILNy`W7`st9sO|WC;Sf& z(Z)YMc<`Up-6E6w7;(b_qyMV%oMzg3GCWyNJQzBGEg$6X;zbH`wf*k~BOZ3Ak4$kq z@!{#T=s|68+9|LnG2VvG5iORsiJi2nVQd%B2|=11Q5L-^^gJF}b07rzF@x+LjU8!; zEQCeL(EzM?0E4;aL?D4o`wB3?Vvo@J@Ny9%o|mPQ4xa<~qM4J61q z7bGOtmdZh&<%r%l_7uW19ga{qZ~UQ0`$xA!#iWvTJ)7PKy5OBbS-xEY8c1YZ%{XY! z$^kT7C*mYsyiA)0Q!zR%m z=NA*}(#uQfclfG#YS|B`CHiAd7FG4PBJ7l=ZJHQsS=(9w7m~HcQyo^*Xz`%s51erG zK!`nDviyBZ;p641LGv&j#<9;sH~E-dL?fP0D6=N$i8AEL|8lflVExcLiLnU)0000< KMNUMnLSTaRK@SiB literal 0 HcmV?d00001 diff --git a/tests/datasets/test_reforestree.py b/tests/datasets/test_reforestree.py new file mode 100644 index 00000000000..1337cfb18c3 --- /dev/null +++ b/tests/datasets/test_reforestree.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import builtins +import os +import shutil +from pathlib import Path +from typing import Any + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.monkeypatch import MonkeyPatch + +import torchgeo.datasets.utils +from torchgeo.datasets import ReforesTree + + +def download_url(url: str, root: str, *args: str) -> None: + shutil.copy(url, root) + + +class TestReforesTree: + @pytest.fixture + def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ReforesTree: + monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) + data_dir = os.path.join("tests", "data", "reforestree") + + url = os.path.join(data_dir, "reforesTree.zip") + + md5 = "387e04dbbb0aa803f72bd6d774409648" + + monkeypatch.setattr(ReforesTree, "url", url) + monkeypatch.setattr(ReforesTree, "md5", md5) + root = str(tmp_path) + transforms = nn.Identity() + return ReforesTree( + root=root, transforms=transforms, download=True, checksum=True + ) + + def test_already_downloaded(self, dataset: ReforesTree) -> None: + ReforesTree(root=dataset.root, download=True) + + def test_getitem(self, dataset: ReforesTree) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert isinstance(x["label"], torch.Tensor) + assert isinstance(x["boxes"], torch.Tensor) + assert isinstance(x["agb"], torch.Tensor) + assert x["image"].shape[0] == 3 + assert x["image"].ndim == 3 + assert len(x["boxes"]) == 2 + + @pytest.fixture + def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: + import_orig = builtins.__import__ + package = "pandas" + + def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: + if name == package: + raise ImportError() + return import_orig(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", mocked_import) + + def test_mock_missing_module( + self, dataset: ReforesTree, mock_missing_module: None + ) -> None: + with pytest.raises( + ImportError, + match="pandas is not installed and is required to use this dataset", + ): + ReforesTree(root=dataset.root) + + def test_len(self, dataset: ReforesTree) -> None: + assert len(dataset) == 2 + + def test_not_extracted(self, tmp_path: Path) -> None: + url = os.path.join("tests", "data", "reforestree", "reforesTree.zip") + shutil.copy(url, tmp_path) + ReforesTree(root=str(tmp_path)) + + def test_corrupted(self, tmp_path: Path) -> None: + with open(os.path.join(tmp_path, "reforesTree.zip"), "w") as f: + f.write("bad") + with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + ReforesTree(root=str(tmp_path), checksum=True) + + def test_not_found(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found in"): + ReforesTree(str(tmp_path)) + + def test_plot(self, dataset: ReforesTree) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + + def test_plot_prediction(self, dataset: ReforesTree) -> None: + x = dataset[0].copy() + x["prediction_boxes"] = x["boxes"].clone() + dataset.plot(x, suptitle="Prediction") + plt.close() diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index e7bbb25b9b8..13196988787 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -75,6 +75,7 @@ from .oscd import OSCD from .patternnet import PatternNet from .potsdam import Potsdam2D +from .reforestree import ReforesTree from .resisc45 import RESISC45 from .seco import SeasonalContrastS2 from .sen12ms import SEN12MS @@ -167,6 +168,7 @@ "PatternNet", "Potsdam2D", "RESISC45", + "ReforesTree", "SeasonalContrastS2", "SEN12MS", "So2Sat", diff --git a/torchgeo/datasets/reforestree.py b/torchgeo/datasets/reforestree.py new file mode 100644 index 00000000000..7ab65b49bc4 --- /dev/null +++ b/torchgeo/datasets/reforestree.py @@ -0,0 +1,291 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""ReforesTree dataset.""" + +import glob +import os +from typing import Callable, Dict, List, Optional, Tuple + +import matplotlib.patches as patches +import matplotlib.pyplot as plt +import numpy as np +import torch +from PIL import Image +from torch import Tensor + +from .geo import VisionDataset +from .utils import check_integrity, download_and_extract_archive, extract_archive + + +class ReforesTree(VisionDataset): + """ReforesTree dataset. + + The `ReforesTree `__ + dataset contains drone imagery that can be used for tree crown detection, + tree species classification and Aboveground Biomass (AGB) estimation. + + Dataset features: + + * 100 high resolution RGB drone images at 2 cm/pixel of size 4,000 x 4,000 px + * more than 4,600 tree crown box annotations + * tree crown matched with field measurements of diameter at breast height (DBH), + and computed AGB and carbon values + + Dataset format: + + * images are three-channel pngs + * annotations are csv file + + Dataset Classes: + + 0. other + 1. banana + 2. cacao + 3. citrus + 4. fruit + 5. timber + + If you use this dataset in your research, please cite the following paper: + + * https://arxiv.org/abs/2201.11192 + + .. versionadded:: 0.3 + """ + + classes = ["other", "banana", "cacao", "citrus", "fruit", "timber"] + url = "https://zenodo.org/record/6813783/files/reforesTree.zip?download=1" + + md5 = "f6a4a1d8207aeaa5fbab7b21b683a302" + zipfilename = "reforesTree.zip" + + def __init__( + self, + root: str = "data", + transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new ReforesTree dataset instance. + + Args: + root: root directory where dataset can be found + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + RuntimeError: if ``download=False`` and data is not found, or checksums + don't match + """ + self.root = root + self.transforms = transforms + self.checksum = checksum + self.download = download + + self._verify() + + try: + import pandas as pd # noqa: F401 + except ImportError: + raise ImportError( + "pandas is not installed and is required to use this dataset" + ) + + self.files = self._load_files(self.root) + + self.annot_df = pd.read_csv(os.path.join(root, "mapping", "final_dataset.csv")) + + self.class2idx: Dict[str, int] = {c: i for i, c in enumerate(self.classes)} + + def __getitem__(self, index: int) -> Dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + data and label at that index + """ + filepath = self.files[index] + + image = self._load_image(filepath) + + boxes, labels, agb = self._load_target(filepath) + + sample = {"image": image, "boxes": boxes, "label": labels, "agb": agb} + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + length of the dataset + """ + return len(self.files) + + def _load_files(self, root: str) -> List[str]: + """Return the paths of the files in the dataset. + + Args: + root: root dir of dataset + + Returns: + list of dicts containing paths for each pair of image, annotation + """ + image_paths = sorted(glob.glob(os.path.join(root, "tiles", "**", "*.png"))) + + return image_paths + + def _load_image(self, path: str) -> Tensor: + """Load a single image. + + Args: + path: path to the image + + Returns: + the image + """ + with Image.open(path) as img: + array: "np.typing.NDArray[np.uint8]" = np.array(img) + tensor = torch.from_numpy(array) + # Convert from HxWxC to CxHxW + tensor = tensor.permute((2, 0, 1)) + return tensor + + def _load_target(self, filepath: str) -> Tuple[Tensor, ...]: + """Load boxes and labels for a single image. + + Args: + filepath: image tile filepath + + Returns: + dictionary containing boxes, label, and agb value + """ + tile_df = self.annot_df[self.annot_df["img_path"] == os.path.basename(filepath)] + + boxes = torch.Tensor(tile_df[["xmin", "ymin", "xmax", "ymax"]].values.tolist()) + labels = torch.Tensor( + [self.class2idx[label] for label in tile_df["group"].tolist()] + ) + agb = torch.Tensor(tile_df["AGB"].tolist()) + + return boxes, labels, agb + + def _verify(self) -> None: + """Checks the integrity of the dataset structure. + + Raises: + RuntimeError: if dataset is not found in root or is corrupted + """ + filepaths = [os.path.join(self.root, dir) for dir in ["tiles", "mapping"]] + if all([os.path.exists(filepath) for filepath in filepaths]): + return + + filepath = os.path.join(self.root, self.zipfilename) + if os.path.isfile(filepath): + if self.checksum and not check_integrity(filepath, self.md5): + raise RuntimeError("Dataset found, but corrupted.") + extract_archive(filepath) + return + + # Check if the user requested to download the dataset + if not self.download: + raise RuntimeError( + f"Dataset not found in `root={self.root}` and `download=False`, " + "either specify a different `root` directory or use `download=True` " + "to automatically download the dataset." + ) + + # else download the dataset + self._download() + + def _download(self) -> None: + """Download the dataset and extract it. + + Raises: + AssertionError: if the checksum does not match + """ + download_and_extract_archive( + self.url, + self.root, + filename=self.zipfilename, + md5=self.md5 if self.checksum else None, + ) + + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + """ + image = sample["image"].permute((1, 2, 0)).numpy() + ncols = 1 + showing_predictions = "prediction_boxes" in sample + if showing_predictions: + ncols += 1 + + fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) + if not showing_predictions: + axs = [axs] + + axs[0].imshow(image) + axs[0].axis("off") + + bboxes = [ + patches.Rectangle( + (bbox[0], bbox[1]), + bbox[2] - bbox[0], + bbox[3] - bbox[1], + linewidth=1, + edgecolor="r", + facecolor="none", + ) + for bbox in sample["boxes"].numpy() + ] + for bbox in bboxes: + axs[0].add_patch(bbox) + + if show_titles: + axs[0].set_title("Ground Truth") + + if showing_predictions: + axs[1].imshow(image) + axs[1].axis("off") + + pred_bboxes = [ + patches.Rectangle( + (bbox[0], bbox[1]), + bbox[2] - bbox[0], + bbox[3] - bbox[1], + linewidth=1, + edgecolor="r", + facecolor="none", + ) + for bbox in sample["prediction_boxes"].numpy() + ] + for bbox in pred_bboxes: + axs[1].add_patch(bbox) + + if show_titles: + axs[1].set_title("Predictions") + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig