-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfit_main.py
executable file
·227 lines (181 loc) · 8.47 KB
/
fit_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
#!/usr/bin/env python
# coding: utf-8
#Import all the dependencies
#This disables python on GPU
#import os
#os.environ["CUDA_VISIBLE_DEVICES"]="-1"
import os
import sys
import time

import numpy as np
from keras import backend as K
from keras.callbacks import EarlyStopping, CSVLogger, ModelCheckpoint
from keras.optimizers import SGD, RMSprop, adam
from keras.utils import np_utils
from scipy import ndarray
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight

from plotter import *
from s2_model import *
from s2_preprocessor import *
# Command line: argv[1] is the version tag to save this run under,
# argv[2] the version the training run starts from.
version, version_start = str(sys.argv[1]), str(sys.argv[2])
#Because fit_generator needs different data preprocessing functions, then we define functions for windowing in this script
#Because fit_generator needs different data preprocessing functions, then we define functions for windowing in this script
def input_windows_preprocessing(preprocessor_X_output, preprocessor_Y_output, s2_preprocessor):
    """Cut one (window x window x images) training sample per valid tile pixel.

    Parameters
    ----------
    preprocessor_X_output : ndarray
        Tile data of shape (tile_dim, tile_dim, nb_bands, 6); channels 0-4 of
        the last axis are the image dates, channel 5 of band 0 carries the
        Region-Of-Interest mask (nonzero = pixel participates in training).
    preprocessor_Y_output : ndarray
        One-hot labels, reshapeable to (tile_dim*tile_dim, nb_classes).
    s2_preprocessor : object
        Provides tile_dimension, window_dimension, nb_images, nb_bands,
        nb_classes.

    Returns
    -------
    X : ndarray, shape (n_valid, nb_bands, window, window, nb_images)
    Y : ndarray, shape (n_valid, nb_classes)

    Fixes vs. the original: deprecated skimage.util.pad -> np.pad; pad width
    was hard-coded to 4 and now follows window_dimension // 2 (identical for
    the default window of 8); the per-pixel Python copy loop (which also
    mis-handled mask values other than exactly 1) is replaced by vectorized
    fancy indexing; X/Y are allocated once instead of concatenating onto
    empty arrays.
    """
    tile_dim = s2_preprocessor.tile_dimension
    win_dim = s2_preprocessor.window_dimension
    nb_tile_pixels = tile_dim * tile_dim

    input_data = preprocessor_X_output.astype('float32')
    input_labels = np.reshape(preprocessor_Y_output, (nb_tile_pixels, s2_preprocessor.nb_classes))

    # Get Region of Interest mask from the loaded array; the remaining
    # channels are the actual image dates.
    ROI_mask = input_data[:, :, 0, 5]
    X_2D_nowindows = input_data[:, :, :, 0:5]

    # Flat indices of the pixels inside the ROI (consistent with the
    # count_nonzero sizing the original used).
    valid_idx = np.flatnonzero(np.reshape(ROI_mask, (nb_tile_pixels,)))

    X = np.zeros((valid_idx.size, s2_preprocessor.nb_bands, win_dim, win_dim, s2_preprocessor.nb_images))
    Y = input_labels[valid_idx].astype(np.float64)

    # Reflect-pad so every tile pixel gets a full window; the slice keeps
    # exactly tile_dim + win_dim - 1 rows/cols so the window grid is
    # tile_dim x tile_dim (the original padded 4 and trimmed one row/col).
    pad = win_dim // 2
    for j in range(s2_preprocessor.nb_images):
        for i in range(s2_preprocessor.nb_bands):
            padded_overpad = np.pad(X_2D_nowindows[:tile_dim, :, i, j], pad, 'reflect')
            padded = padded_overpad[:tile_dim + win_dim - 1, :tile_dim + win_dim - 1]
            # Zero-copy view of every win_dim x win_dim window; one window
            # per tile pixel after the reshape.
            windows = np.lib.stride_tricks.sliding_window_view(padded, (win_dim, win_dim))
            reshaped_windows = np.reshape(windows, (nb_tile_pixels, win_dim, win_dim))
            X[:, i, :, :, j] = reshaped_windows[valid_idx]
    return X, Y
# Sentinel-2 preprocessor configuration.
# NOTE(review): the instantiation below rebinds the imported
# s2_preprocessor class name to the instance, so the class cannot be
# constructed a second time in this module.
s2_preprocessor_params = {
    'input_dimension': 5120,
    'label_dir': './Label_tifs/',
    'data_dir': './Data/',
    'input_data_dir': './Big_tile_data/',
    'region_of_interest_shapefile': './ROI/ROI.shp',
    'window_dimension': 8,
    'tile_dimension': 512,
    'nb_images': 5,
    'nb_bands': 22,
    'nb_steps': 8,  # unused: intended count of tile splits for training
    'rotation_augmentation': 0,
    'flipping_augmentation': 0,
}
s2_preprocessor = s2_preprocessor(**s2_preprocessor_params)
# Pre-computed class weights for the imbalanced label distribution.
class_weights = np.load("class_weights.npy")

# Optimiser settings; gradient clipping ('clipvalue': 0.5) was tried and
# left disabled.
optimizer_params = {
    'lr': 0.001,
}

# Callback that turns CTRL+Z into a clean training stop
# (SignalStopping comes in via the star imports above).
stop_cb = SignalStopping()

# Keep the weights with the lowest validation loss seen so far.
filepath = "best_model.h5"
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1,
                             save_best_only=True, mode='min')

early_stopping_params = {
    'monitor': 'val_loss',
    'min_delta': 0.,
    'patience': 5,
    'verbose': 1,
}
# Model/training hyper-parameters handed to the project s2_model wrapper.
s2_model_params = {
    's2_preprocessor': s2_preprocessor,
    'batch_size': 512,
    'nb_epochs': 1000,
    'nb_filters': [32, 32, 64],
    'max_pool_size': [2, 2, 1],
    'conv_kernel_size': [3, 3, 3],
    'optimizer': SGD(**optimizer_params),
    'loss_function': 'categorical_crossentropy',
    'metrics': ['mse', 'accuracy'],
    'version': version_start,
    'cb_list': [EarlyStopping(**early_stopping_params), stop_cb, checkpoint],
}
# NOTE(review): this rebinding shadows the imported s2_model class name.
s2_model = s2_model(**s2_model_params)
# --- Run configuration -------------------------------------------------
selected_tile = "1-1"  # big-tile id; put as "xx" or "0-0"
list_of_zero_tiles = []  # small tiles found to contain (almost) no labels
select_mode = 1   # 1 = train only on current_tile, 0 = sweep all tiles
predict_mode = 0  # 1 = plot predictions instead of training
# NB!! 1st element is tile in y-dimension, 2nd element is tile in x-dimension
current_tile = [9, 3]

plotter = plotter(s2_preprocessor, cmap='tab10')

# Build the label map for the selected big tile and split it into small tiles.
label_map = s2_preprocessor.construct_label_map(selected_tile)
label_map_tiles = s2_preprocessor.tile_label_map(label_map)

# Resume from the rolling checkpoint when one exists.
if os.path.exists('current.h5'):
    s2_model.load("current.h5")
# Walk every small tile of the big tile (and every augmentation of it),
# train on the tiles that pass the filters below, and checkpoint the model
# after each fit.
for a in range(int(round(s2_preprocessor.input_dimension / s2_preprocessor.tile_dimension))):
    for b in range(int(round(s2_preprocessor.input_dimension / s2_preprocessor.tile_dimension))):
        for augmentation_nr in range(s2_preprocessor.nb_augmentations):
            tile_location = [a, b]
            # Skip tiles already known to be (nearly) label-free.
            if tile_location in list_of_zero_tiles:
                continue
            # In select mode only the single configured tile is trained on.
            if select_mode == 1 and (a != current_tile[0] or b != current_tile[1]):
                continue

            data = s2_preprocessor.construct_input_data(tile_location, selected_tile)
            labels = s2_preprocessor.construct_labels(label_map_tiles, tile_location)
            print('Labels size: ' + str(sys.getsizeof(labels)))

            labels_unique = np.unique(labels)
            labels_size = labels.size
            zero_percentage = (labels_size - np.count_nonzero(labels)) / labels_size
            print("Zero percentage is:" + str(zero_percentage))

            # Lower precision to use 2x less RAM.
            input_data = data.astype('float32')
            del data

            # One-hot encode the labels, then window the inputs per pixel.
            one_hot_labels = np_utils.to_categorical(labels, num_classes=s2_preprocessor.nb_classes)
            del labels
            X, Y = input_windows_preprocessing(input_data, one_hot_labels, s2_preprocessor)

            if predict_mode == 1:
                y_predictions = s2_model.predict(input_data)
                plotter.plot_model_prediction(y_predictions, tile_location, label_map_tiles)
                input("Press Enter to continue...")

            # Half of the windows are held out for validation.
            X_train, X_val, Y_train, Y_val = train_test_split(X, Y, test_size=0.5, random_state=4)

            start_time = time.time()
            hist = s2_model.fit(X_train=X_train, Y_train=Y_train, X_val=X_val, Y_val=Y_val)
            time_elapsed = time.time() - start_time

            # Rolling checkpoint plus a per-version snapshot.
            s2_model.save("current.h5")
            s2_model.save("Models/" + version + ".h5")
# Summarise the finished run and persist training curves plus metadata.
# Fix vs. original: hist.history['loss'] was read into train_loss twice
# (once for epochs_done, once for the curve list); it is read once here.
# NOTE(review): the 'acc'/'val_acc' keys are the Keras 1/2 names — newer
# Keras reports 'accuracy'/'val_accuracy'; confirm against the installed
# version.
train_loss = hist.history['loss']
val_loss = hist.history['val_loss']
train_acc = hist.history['acc']
val_acc = hist.history['val_acc']
epochs_done = len(train_loss)

# Strip the entries that do not serialise cleanly (model objects,
# callbacks) before dumping the parameter dict with the run metadata.
del s2_model_params['s2_preprocessor']
del s2_model_params['optimizer']
del s2_model_params['cb_list']

metadata_dict = {
    'Epochs_done': epochs_done,
    'Starting_version': version_start,
    'Version': version,
    'Fit': "Trained using keras.fit() function",
    'Time_elapsed': time_elapsed,
    'big_tile': selected_tile,
    'small_tile': current_tile,
    's2_model_params': s2_model_params,
    'optimizer': optimizer_params,
    'early_stopping': early_stopping_params,
    's2_preprocessor_params': s2_preprocessor_params,
}

npy_save_list = [train_loss, train_acc, val_loss, val_acc]
np.save('Models/npy_save_list' + version + '.npy', npy_save_list)
np.save('Models/metadata' + version + '.npy', metadata_dict)