-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathunet3plus.py
104 lines (73 loc) · 4.46 KB
/
unet3plus.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
UNet3+ base model
"""
import tensorflow as tf
import tensorflow.keras as k
from .unet3plus_utils import conv_block
def _pooled_branch(x, pool, cat_channels):
    """Max-pool ``x`` by ``pool`` (skipped when ``pool == 1``), then project it.

    Used for the same-scale and shallower (higher-resolution) encoder inputs
    of a decoder stage; ``conv_block`` maps the result to ``cat_channels``.
    """
    if pool > 1:
        x = k.layers.MaxPool2D(pool_size=(pool, pool))(x)
    return conv_block(x, cat_channels, n=1)


def _upsampled_branch(x, size, cat_channels):
    """Bilinearly upsample ``x`` by ``size``, then project it to ``cat_channels``.

    Used for the deeper (lower-resolution) inputs of a decoder stage.
    """
    x = k.layers.UpSampling2D(size=(size, size), interpolation='bilinear')(x)
    return conv_block(x, cat_channels, n=1)


def _decoder_stage(shallower, deeper, cat_channels, fused_channels):
    """Build one full-scale-skip decoder stage of UNet3+.

    Args:
        shallower: encoder maps e1..e_i, finest first. The last one is at this
            stage's resolution (no pooling); each earlier one is max-pooled by
            an extra factor of 2.
        deeper: maps deeper than this stage (decoder stages d_{i+1}.. followed
            by the bottleneck e5), shallowest first; each is upsampled by an
            extra factor of 2 (2, 4, 8, ...).
        cat_channels: channels produced by every individual branch.
        fused_channels: channels of the conv that fuses the concatenation.

    Returns:
        The fused decoder feature map for this stage.
    """
    n_shallow = len(shallower)
    # Branches are built shallowest-to-deepest. Keras assigns default layer
    # names in creation order, so keep this order stable.
    branches = [
        _pooled_branch(x, 2 ** (n_shallow - 1 - j), cat_channels)
        for j, x in enumerate(shallower)
    ]
    branches += [
        _upsampled_branch(x, 2 ** (m + 1), cat_channels)
        for m, x in enumerate(deeper)
    ]
    fused = k.layers.concatenate(branches)
    return conv_block(fused, fused_channels, n=1)


def unet3plus(encoder_layer, output_channels, filters):
    """Build the UNet3+ decoder and output head on top of encoder features.

    Each decoder stage (d4, d3, d2, d1) receives a full-scale skip connection
    from every encoder level at or above its resolution and from every deeper
    stage. Shallower maps are max-pooled down and deeper maps are bilinearly
    upsampled. Each branch is projected to ``filters[0]`` channels, and the
    concatenation is fused by a further ``conv_block``.

    Args:
        encoder_layer: indexable whose first five items are the encoder
            feature maps e1..e5, finest to coarsest; extra items are ignored.
            The pooling/upsampling factors assume each level has half the
            spatial size of the previous one (e.g. 320, 160, 80, 40, 20).
            Confirm this against the encoder that produces them.
        output_channels: channels of the output map. 1 selects a sigmoid
            output; any other value selects softmax.
        filters: sequence of ints. ``filters[0]`` is the width of every skip
            branch; ``len(filters) * filters[0]`` is the width of each fused
            decoder stage.

    Returns:
        Tuple ``(output, 'UNet_3Plus')``: the output tensor and the model name.
    """
    # Encoder feature maps e1..e5, finest resolution first.
    encoders = [encoder_layer[i] for i in range(5)]

    cat_channels = filters[0]
    upsample_channels = len(filters) * cat_channels

    # Maps deeper than the stage being built, shallowest first. The bottleneck
    # e5 serves directly as the deepest input of every decoder stage.
    deeper = [encoders[4]]
    for level in range(4, 0, -1):  # builds d4, then d3, d2, d1
        stage = _decoder_stage(
            encoders[:level], deeper, cat_channels, upsample_channels
        )
        deeper.insert(0, stage)
    d1 = deeper[0]

    # Last layer does not have batchnorm and relu.
    d = conv_block(d1, output_channels, n=1, is_bn=False, is_relu=False)
    # The head is pinned to float32, presumably so the output stays float32
    # under a mixed-precision policy.
    activation = 'sigmoid' if output_channels == 1 else 'softmax'
    output = k.layers.Activation(activation, dtype='float32')(d)
    return output, 'UNet_3Plus'