# bigearthnet_dataset.py
# BigEarthNet dataset wrapper (torch Dataset) and LMDB export script.
import json
from pathlib import Path
import numpy as np
import rasterio
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_and_extract_archive, download_url
# Band identifiers for which BAND_STATS below provides a mean and a std.
# These are the suffixes used in the per-band GeoTIFF file names
# ("<patch_id>_<band>.tif") read by Bigearthnet.__getitem__.
ALL_BANDS = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12']
# Default band selection used by Bigearthnet when `bands` is None; the list
# order is the channel order of the stacked image.
# NOTE(review): presumably B04/B03/B02 are the red/green/blue bands — confirm.
RGB_BANDS = ['B04', 'B03', 'B02']
# Per-band statistics, keyed first by statistic ('mean' / 'std') and then by
# band identifier. Bigearthnet.__getitem__ looks up both values for every
# requested band and passes them to normalize().
# NOTE(review): the provenance and units of these numbers are not visible in
# this file — presumably computed over the raw pixel values; confirm.
BAND_STATS = {
'mean': {
'B01': 340.76769064,
'B02': 429.9430203,
'B03': 614.21682446,
'B04': 590.23569706,
'B05': 950.68368468,
'B06': 1792.46290469,
'B07': 2075.46795189,
'B08': 2218.94553375,
'B8A': 2266.46036911,
'B09': 2246.0605464,
'B11': 1594.42694882,
'B12': 1009.32729131
},
'std': {
'B01': 554.81258967,
'B02': 572.41639287,
'B03': 582.87945694,
'B04': 675.88746967,
'B05': 729.89827633,
'B06': 1096.01480586,
'B07': 1273.45393088,
'B08': 1365.45589904,
'B8A': 1356.13789355,
'B09': 1302.3292881,
'B11': 1079.19066363,
'B12': 818.86747235
}
}
# Original label vocabulary (43 entries). The position of a label in this
# list is its index in the multi-hot vector built by
# Bigearthnet.get_multihot_old, so the order must not change.
LABELS = [
'Agro-forestry areas', 'Airports',
'Annual crops associated with permanent crops', 'Bare rock',
'Beaches, dunes, sands', 'Broad-leaved forest', 'Burnt areas',
'Coastal lagoons', 'Complex cultivation patterns', 'Coniferous forest',
'Construction sites', 'Continuous urban fabric',
'Discontinuous urban fabric', 'Dump sites', 'Estuaries',
'Fruit trees and berry plantations', 'Green urban areas',
'Industrial or commercial units', 'Inland marshes', 'Intertidal flats',
# The next two string literals are adjacent and therefore concatenated into
# a single label ("Land principally occupied by agriculture, ... vegetation").
'Land principally occupied by agriculture, with significant areas of '
'natural vegetation', 'Mineral extraction sites', 'Mixed forest',
'Moors and heathland', 'Natural grassland', 'Non-irrigated arable land',
'Olive groves', 'Pastures', 'Peatbogs', 'Permanently irrigated land',
'Port areas', 'Rice fields', 'Road and rail networks and associated land',
'Salines', 'Salt marshes', 'Sclerophyllous vegetation', 'Sea and ocean',
'Sparsely vegetated areas', 'Sport and leisure facilities',
'Transitional woodland/shrub', 'Vineyards', 'Water bodies', 'Water courses'
]
# Reduced label vocabulary (19 entries). The position of a label in this list
# is its index in the multi-hot vector built by
# Bigearthnet.get_multihot_new, so the order must not change. Some entries
# are merged classes that only appear as values of GROUP_LABELS; the others
# are carried over unchanged from LABELS.
NEW_LABELS = [
'Urban fabric',
'Industrial or commercial units',
'Arable land',
'Permanent crops',
'Pastures',
'Complex cultivation patterns',
'Land principally occupied by agriculture, with significant areas of natural vegetation',
'Agro-forestry areas',
'Broad-leaved forest',
'Coniferous forest',
'Mixed forest',
'Natural grassland and sparsely vegetated areas',
'Moors, heathland and sclerophyllous vegetation',
'Transitional woodland/shrub',
'Beaches, dunes, sands',
'Inland wetlands',
'Coastal wetlands',
'Inland waters',
'Marine waters'
]
# Maps an original label (key, from LABELS) to the merged class (value, from
# NEW_LABELS) that replaces it. Bigearthnet.get_multihot_new consults this
# mapping first; a label that is neither a key here nor itself present in
# NEW_LABELS (e.g. 'Airports') is skipped and contributes nothing to the
# target vector.
GROUP_LABELS = {
'Continuous urban fabric': 'Urban fabric',
'Discontinuous urban fabric': 'Urban fabric',
'Non-irrigated arable land': 'Arable land',
'Permanently irrigated land': 'Arable land',
'Rice fields': 'Arable land',
'Vineyards': 'Permanent crops',
'Fruit trees and berry plantations': 'Permanent crops',
'Olive groves': 'Permanent crops',
'Annual crops associated with permanent crops': 'Permanent crops',
'Natural grassland': 'Natural grassland and sparsely vegetated areas',
'Sparsely vegetated areas': 'Natural grassland and sparsely vegetated areas',
'Moors and heathland': 'Moors, heathland and sclerophyllous vegetation',
'Sclerophyllous vegetation': 'Moors, heathland and sclerophyllous vegetation',
'Inland marshes': 'Inland wetlands',
'Peatbogs': 'Inland wetlands',
'Salt marshes': 'Coastal wetlands',
'Salines': 'Coastal wetlands',
'Water bodies': 'Inland waters',
'Water courses': 'Inland waters',
'Coastal lagoons': 'Marine waters',
'Estuaries': 'Marine waters',
'Sea and ocean': 'Marine waters'
}
def normalize(img, mean, std):
    """Linearly rescale ``img`` to 8 bits around ``mean``.

    The window ``[mean - 2 * std, mean + 2 * std]`` is mapped onto
    ``[0, 255]``; anything outside the window is clipped to the nearest
    bound before the result is cast to ``uint8`` (fractional parts are
    dropped by the cast).

    Args:
        img: Array of pixel values.
        mean: Centre of the rescaling window.
        std: Quarter-width of the rescaling window.

    Returns:
        ``uint8`` array with the same shape as ``img``.
    """
    lower = mean - 2 * std
    upper = mean + 2 * std
    # Position inside the window as a fraction, then stretched to 0..255.
    fraction = (img - lower) / (upper - lower)
    stretched = fraction * 255.0
    return np.clip(stretched, 0, 255).astype(np.uint8)
class Bigearthnet(Dataset):
    """BigEarthNet patches exposed as a multi-label image dataset.

    Each item is an ``(image, target)`` pair. The image is a PIL image built
    by stacking the requested bands after each one has been rescaled to 8
    bits with ``normalize``. The target is a float32 multi-hot vector over
    ``NEW_LABELS`` (default) or ``LABELS``.

    Args:
        root: Directory holding (or receiving, when ``download`` is true) the
            extracted archive, the ``<split>.txt`` patch list and the
            bad-patch CSV files.
        split: Name of the split. It is the stem of the patch-list file read
            from ``root`` and, when downloading, a key of ``list_file``.
        bands: Band identifiers to load, in channel order. Defaults to
            ``RGB_BANDS``.
        transform: Optional callable applied to the PIL image.
        target_transform: Optional callable applied to the multi-hot target.
        download: When true, fetch the archive, the split list and the
            bad-patch lists into ``root`` before indexing.
        use_new_labels: Encode targets over ``NEW_LABELS`` when true,
            otherwise over ``LABELS``.
    """

    # Archive with the image patches and the directory it extracts to.
    url = 'http://bigearth.net/downloads/BigEarthNet-v1.0.tar.gz'
    subdir = 'BigEarthNet-v1.0'
    # One patch-id list per split; saved locally as '<split>.txt'.
    list_file = {
        'train': 'https://storage.googleapis.com/remote_sensing_representations/bigearthnet-train.txt',
        'val': 'https://storage.googleapis.com/remote_sensing_representations/bigearthnet-val.txt',
        'test': 'https://storage.googleapis.com/remote_sensing_representations/bigearthnet-test.txt'
    }
    # Lists of patch ids to exclude from every split.
    bad_patches = [
        'http://bigearth.net/static/documents/patches_with_seasonal_snow.csv',
        'http://bigearth.net/static/documents/patches_with_cloud_and_shadow.csv'
    ]

    def __init__(self, root, split, bands=None, transform=None, target_transform=None, download=False, use_new_labels=True):
        self.root = Path(root)
        self.split = split
        self.bands = bands if bands is not None else RGB_BANDS
        self.transform = transform
        self.target_transform = target_transform
        self.use_new_labels = use_new_labels

        if download:
            download_and_extract_archive(self.url, self.root)
            download_url(self.list_file[self.split], self.root, f'{self.split}.txt')
            for url in self.bad_patches:
                download_url(url, self.root)

        # Collect the ids listed in the bad-patch files; each file is expected
        # in `root` under the last path component of its URL.
        excluded = set()
        for url in self.bad_patches:
            with open(self.root / Path(url).name) as f:
                excluded.update(f.read().splitlines())

        # Keep every patch of the split that is not excluded, as the path of
        # its directory inside the extracted archive.
        with open(self.root / f'{self.split}.txt') as f:
            self.samples = [
                self.root / self.subdir / patch_id
                for patch_id in f.read().splitlines()
                if patch_id not in excluded
            ]

    def __getitem__(self, index):
        """Load the patch at ``index`` and return ``(image, target)``."""
        path = self.samples[index]
        patch_id = path.name

        channels = []
        for b in self.bands:
            # Use the dataset as a context manager so the underlying file
            # handle is released as soon as the band has been read.
            with rasterio.open(path / f'{patch_id}_{b}.tif') as src:
                ch = src.read(1)
            ch = normalize(ch, mean=BAND_STATS['mean'][b], std=BAND_STATS['std'][b])
            channels.append(ch)
        img = Image.fromarray(np.dstack(channels))

        with open(path / f'{patch_id}_labels_metadata.json', 'r') as f:
            labels = json.load(f)['labels']
        if self.use_new_labels:
            target = self.get_multihot_new(labels)
        else:
            target = self.get_multihot_old(labels)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of patches kept for this split."""
        return len(self.samples)

    @staticmethod
    def get_multihot_old(labels):
        """Encode ``labels`` as a float32 multi-hot vector over ``LABELS``.

        Raises:
            ValueError: If a label is not part of ``LABELS``.
        """
        target = np.zeros((len(LABELS),), dtype=np.float32)
        for label in labels:
            target[LABELS.index(label)] = 1
        return target

    @staticmethod
    def get_multihot_new(labels):
        """Encode ``labels`` as a float32 multi-hot vector over ``NEW_LABELS``.

        Labels listed in ``GROUP_LABELS`` are replaced by their merged class;
        labels with no counterpart in ``NEW_LABELS`` are ignored.
        """
        target = np.zeros((len(NEW_LABELS),), dtype=np.float32)
        for label in labels:
            if label in GROUP_LABELS:
                label = GROUP_LABELS[label]
            elif label not in NEW_LABELS:
                # No equivalent in the reduced vocabulary: drop it.
                continue
            target[NEW_LABELS.index(label)] = 1
        return target
if __name__ == '__main__':
    # Script entry point: build the train and val datasets from --data_dir and
    # hand each one to the project helper make_lmdb together with the target
    # '<split>.lmdb' path under --save_dir.
    import os
    import argparse

    from utils.data import make_lmdb

    parser = argparse.ArgumentParser()
    # Both paths are mandatory: left unset they would be None and fail later
    # with an opaque TypeError inside Path() / os.path.join().
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--save_dir', type=str, required=True)
    args = parser.parse_args()

    for split in ('train', 'val'):
        dataset = Bigearthnet(root=args.data_dir, split=split)
        make_lmdb(dataset, lmdb_file=os.path.join(args.save_dir, f'{split}.lmdb'))