-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmodel1.py
89 lines (63 loc) · 1.93 KB
/
model1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import numpy as np
import cv2
from keras.preprocessing import image
from keras.models import Model
from keras.layers import Input, Conv2D, Deconv2D, Activation, BatchNormalization, add
from keras.callbacks import ModelCheckpoint
from datagen import gen_data
# ---- Training configuration -------------------------------------------------
SEED = 1              # base seed for the data-generator RandomState (val uses SEED + 1)
EPOCHS = 40           # passes over the 512-sample-per-epoch training stream
BATCH_SIZE = 4        # samples per batch requested from gen_data
LOAD_WEIGHTS = False  # True -> load model1.h5 before compiling/training
# Nominal image size; the network itself is fully convolutional and does not
# reference these in this file.
IMG_HEIGHT = 128
IMG_WIDTH = 128
def _conv_bn_relu(tensor, kernel_size, filters=64):
    """Apply Conv2D ('same' padding) -> BatchNormalization -> ReLU to *tensor*.

    Args:
        tensor: Keras tensor to transform.
        kernel_size: Convolution kernel size (int or tuple).
        filters: Number of convolution filters (default 64).

    Returns:
        The activated output tensor; spatial size is unchanged because of
        'same' padding.
    """
    tensor = Conv2D(filters, kernel_size, padding='same')(tensor)
    tensor = BatchNormalization()(tensor)
    return Activation('relu')(tensor)


# Fully convolutional network: both spatial dimensions are None, so any
# H x W single-channel image is accepted.
inputs = Input((None, None, 1))
# One 9x9 stem layer followed by seven 3x3 layers, all with 64 filters.
# The layers are created in the same order as the original unrolled code,
# so weight files saved by earlier versions still load by topology.
x = _conv_bn_relu(inputs, 9)
for _ in range(7):
    x = _conv_bn_relu(x, 3)
# Single-channel output squashed into (0, 1) by the sigmoid.
outputs = Conv2D(1, 3, padding='same', activation='sigmoid')(x)
model = Model(inputs=inputs, outputs=outputs)
model.summary()

if LOAD_WEIGHTS:
    # Resume from the file the checkpoint callback below writes.
    model.load_weights('model1.h5')

model.compile(loss='MSE', optimizer='Adam')
# Writes model1.h5 after every epoch. save_best_only is left at its default
# (False), so each epoch overwrites the previous file regardless of val loss.
checkpointer = ModelCheckpoint(filepath='model1.h5', verbose=1)
def _batch_generator(seed):
    """Yield ``gen_data(rnd, BATCH_SIZE)`` forever from a private RandomState.

    The RandomState is created on the first ``next()`` call (this is a
    generator function), so each generator object has its own reproducible
    random stream.

    Args:
        seed: Seed for the ``np.random.RandomState`` passed to ``gen_data``.

    Yields:
        Whatever ``gen_data`` returns; presumably an (inputs, targets) batch
        as ``fit_generator`` expects. TODO confirm against datagen.gen_data.
    """
    rnd = np.random.RandomState(seed)
    while True:
        yield gen_data(rnd, BATCH_SIZE)


def _train_generator():
    """Return an endless training-batch generator seeded with SEED."""
    return _batch_generator(SEED)


def _val_generator():
    """Return an endless validation-batch generator seeded with SEED + 1.

    The seed differs from training so validation batches come from a separate
    random stream. NOTE(review): the RandomState is never reset, so if the
    same generator object is reused across epochs, each epoch presumably
    validates on different batches rather than a fixed held-out set. Confirm
    this is intended.
    """
    return _batch_generator(SEED + 1)
# Each epoch draws 512 training samples and 32 validation samples.
_train_steps = 512 // BATCH_SIZE
_val_steps = 32 // BATCH_SIZE

train_generator, val_generator = _train_generator(), _val_generator()

# fit_generator (rather than fit) is kept for the older Keras API this script
# targets.
history = model.fit_generator(
    train_generator,
    steps_per_epoch=_train_steps,
    epochs=EPOCHS,
    validation_data=val_generator,
    validation_steps=_val_steps,
    callbacks=[checkpointer],
)

# Save the final model (the checkpoint callback separately writes model1.h5).
model.save('model1_final.h5')