-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy path resnet_builder.py
30 lines (24 loc) · 1.01 KB
/
resnet_builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from keras.models import Model
from keras.layers import Input, Activation, Dense, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D, AveragePooling2D
from keras.layers.merge import add
from keras.layers.normalization import BatchNormalization
from layers import *
from keras.regularizers import l2
from keras import backend as k
from resnet50 import ResNet50
def definenetFeat(input_shape, **kwargs):
    """Return the ResNet50 feature network used for the perception loss.

    Thin wrapper around ``get_ResNet50``: ``input_shape`` and every keyword
    argument (e.g. ``trainable``, ``pop``) are passed through unchanged.
    """
    return get_ResNet50(input_shape, **kwargs)
def get_ResNet50(input_shape, trainable=False, pop=True, **kwargs):
    """Build an ImageNet-pretrained ResNet50 feature extractor for a perception loss.

    Args:
        input_shape: Input shape forwarded to ``ResNet50``.
        trainable: Value assigned to ``layer.trainable`` on every layer of the
            returned model (the default ``False`` freezes the pretrained weights).
        pop: When truthy, drop the last two layers of the backbone (described
            in the original code as the pooling layer and the last activation)
            so the model outputs the tensor that fed them.
        **kwargs: Accepted for call-site compatibility; ignored.

    Returns:
        A Keras ``Model``. Prints a header line and ``model.summary()`` as a
        side effect.
    """
    # Convolutional part of ResNet50 with ImageNet weights (no classifier head).
    model = ResNet50(include_top=False, weights='imagenet', input_shape=input_shape)
    if pop:
        # `model.layers.pop()` only edits a Python list. The model's output
        # tensor is untouched, so the "popped" layers would still run.
        # In newer Keras, `model.layers` is a copy and the pop is a no-op.
        # Build a new Model that really ends before the last two layers.
        # NOTE(review): assumes the last two layers of resnet50.ResNet50 with
        # include_top=False are the pooling layer and the final activation
        # (in that topological order) — confirm against resnet50.py.
        model = Model(inputs=model.input, outputs=model.layers[-3].output)
    # Freeze (or unfreeze) every layer of the feature extractor.
    for layer in model.layers:
        layer.trainable = trainable
    print('Resnet50 for Perception loss:')
    model.summary()
    return model