15
15
from keras import layers
16
16
17
17
from keras_nlp .src .api_export import keras_nlp_export
18
- from keras_nlp .src .models .backbone import Backbone
18
+ from keras_nlp .src .models .feature_pyramid_backbone import FeaturePyramidBackbone
19
19
20
20
21
21
@keras_nlp_export ("keras_nlp.models.CSPDarkNetBackbone" )
22
- class CSPDarkNetBackbone (Backbone ):
22
+ class CSPDarkNetBackbone (FeaturePyramidBackbone ):
23
23
"""This class represents Keras Backbone of CSPDarkNet model.
24
24
25
25
This class implements a CSPDarkNet backbone as described in
@@ -65,12 +65,15 @@ def __init__(
65
65
self ,
66
66
stackwise_num_filters ,
67
67
stackwise_depth ,
68
- include_rescaling ,
68
+ include_rescaling = True ,
69
69
block_type = "basic_block" ,
70
- image_shape = (224 , 224 , 3 ),
70
+ image_shape = (None , None , 3 ),
71
71
** kwargs ,
72
72
):
73
73
# === Functional Model ===
74
+ channel_axis = (
75
+ - 1 if keras .config .image_data_format () == "channels_last" else 1
76
+ )
74
77
apply_ConvBlock = (
75
78
apply_darknet_conv_block_depthwise
76
79
if block_type == "depthwise_block"
@@ -83,15 +86,22 @@ def __init__(
83
86
if include_rescaling :
84
87
x = layers .Rescaling (scale = 1 / 255.0 )(x )
85
88
86
- x = apply_focus (name = "stem_focus" )(x )
89
+ x = apply_focus (channel_axis , name = "stem_focus" )(x )
87
90
x = apply_darknet_conv_block (
88
- base_channels , kernel_size = 3 , strides = 1 , name = "stem_conv"
91
+ base_channels ,
92
+ channel_axis ,
93
+ kernel_size = 3 ,
94
+ strides = 1 ,
95
+ name = "stem_conv" ,
89
96
)(x )
97
+
98
+ pyramid_outputs = {}
90
99
for index , (channels , depth ) in enumerate (
91
100
zip (stackwise_num_filters , stackwise_depth )
92
101
):
93
102
x = apply_ConvBlock (
94
103
channels ,
104
+ channel_axis ,
95
105
kernel_size = 3 ,
96
106
strides = 2 ,
97
107
name = f"dark{ index + 2 } _conv" ,
@@ -100,17 +110,20 @@ def __init__(
100
110
if index == len (stackwise_depth ) - 1 :
101
111
x = apply_spatial_pyramid_pooling_bottleneck (
102
112
channels ,
113
+ channel_axis ,
103
114
hidden_filters = channels // 2 ,
104
115
name = f"dark{ index + 2 } _spp" ,
105
116
)(x )
106
117
107
118
x = apply_cross_stage_partial (
108
119
channels ,
120
+ channel_axis ,
109
121
num_bottlenecks = depth ,
110
122
block_type = "basic_block" ,
111
123
residual = (index != len (stackwise_depth ) - 1 ),
112
124
name = f"dark{ index + 2 } _csp" ,
113
125
)(x )
126
+ pyramid_outputs [f"P{ index + 2 } " ] = x
114
127
115
128
super ().__init__ (inputs = image_input , outputs = x , ** kwargs )
116
129
@@ -120,6 +133,7 @@ def __init__(
120
133
self .include_rescaling = include_rescaling
121
134
self .block_type = block_type
122
135
self .image_shape = image_shape
136
+ self .pyramid_outputs = pyramid_outputs
123
137
124
138
def get_config (self ):
125
139
config = super ().get_config ()
@@ -135,7 +149,7 @@ def get_config(self):
135
149
return config
136
150
137
151
138
- def apply_focus (name = None ):
152
+ def apply_focus (channel_axis , name = None ):
139
153
"""A block used in CSPDarknet to focus information into channels of the
140
154
image.
141
155
@@ -151,7 +165,7 @@ def apply_focus(name=None):
151
165
"""
152
166
153
167
def apply (x ):
154
- return layers .Concatenate (name = name )(
168
+ return layers .Concatenate (axis = channel_axis , name = name )(
155
169
[
156
170
x [..., ::2 , ::2 , :],
157
171
x [..., 1 ::2 , ::2 , :],
@@ -164,7 +178,13 @@ def apply(x):
164
178
165
179
166
180
def apply_darknet_conv_block (
167
- filters , kernel_size , strides , use_bias = False , activation = "silu" , name = None
181
+ filters ,
182
+ channel_axis ,
183
+ kernel_size ,
184
+ strides ,
185
+ use_bias = False ,
186
+ activation = "silu" ,
187
+ name = None ,
168
188
):
169
189
"""
170
190
The basic conv block used in Darknet. Applies Conv2D followed by a
@@ -193,11 +213,12 @@ def apply(inputs):
193
213
kernel_size ,
194
214
strides ,
195
215
padding = "same" ,
216
+ data_format = keras .config .image_data_format (),
196
217
use_bias = use_bias ,
197
218
name = name + "_conv" ,
198
219
)(inputs )
199
220
200
- x = layers .BatchNormalization (name = name + "_bn" )(x )
221
+ x = layers .BatchNormalization (axis = channel_axis , name = name + "_bn" )(x )
201
222
202
223
if activation == "silu" :
203
224
x = layers .Lambda (lambda x : keras .activations .silu (x ))(x )
@@ -212,7 +233,7 @@ def apply(inputs):
212
233
213
234
214
235
def apply_darknet_conv_block_depthwise (
215
- filters , kernel_size , strides , activation = "silu" , name = None
236
+ filters , channel_axis , kernel_size , strides , activation = "silu" , name = None
216
237
):
217
238
"""
218
239
The depthwise conv block used in CSPDarknet.
@@ -236,9 +257,13 @@ def apply_darknet_conv_block_depthwise(
236
257
237
258
def apply (inputs ):
238
259
x = layers .DepthwiseConv2D (
239
- kernel_size , strides , padding = "same" , use_bias = False
260
+ kernel_size ,
261
+ strides ,
262
+ padding = "same" ,
263
+ data_format = keras .config .image_data_format (),
264
+ use_bias = False ,
240
265
)(inputs )
241
- x = layers .BatchNormalization ()(x )
266
+ x = layers .BatchNormalization (axis = channel_axis )(x )
242
267
243
268
if activation == "silu" :
244
269
x = layers .Lambda (lambda x : keras .activations .swish (x ))(x )
@@ -248,7 +273,11 @@ def apply(inputs):
248
273
x = layers .LeakyReLU (0.1 )(x )
249
274
250
275
x = apply_darknet_conv_block (
251
- filters , kernel_size = 1 , strides = 1 , activation = activation
276
+ filters ,
277
+ channel_axis ,
278
+ kernel_size = 1 ,
279
+ strides = 1 ,
280
+ activation = activation ,
252
281
)(x )
253
282
254
283
return x
@@ -258,6 +287,7 @@ def apply(inputs):
258
287
259
288
def apply_spatial_pyramid_pooling_bottleneck (
260
289
filters ,
290
+ channel_axis ,
261
291
hidden_filters = None ,
262
292
kernel_sizes = (5 , 9 , 13 ),
263
293
activation = "silu" ,
@@ -291,6 +321,7 @@ def apply_spatial_pyramid_pooling_bottleneck(
291
321
def apply (x ):
292
322
x = apply_darknet_conv_block (
293
323
hidden_filters ,
324
+ channel_axis ,
294
325
kernel_size = 1 ,
295
326
strides = 1 ,
296
327
activation = activation ,
@@ -304,13 +335,15 @@ def apply(x):
304
335
kernel_size ,
305
336
strides = 1 ,
306
337
padding = "same" ,
338
+ data_format = keras .config .image_data_format (),
307
339
name = f"{ name } _maxpool_{ kernel_size } " ,
308
340
)(x [0 ])
309
341
)
310
342
311
- x = layers .Concatenate (name = f"{ name } _concat" )(x )
343
+ x = layers .Concatenate (axis = channel_axis , name = f"{ name } _concat" )(x )
312
344
x = apply_darknet_conv_block (
313
345
filters ,
346
+ channel_axis ,
314
347
kernel_size = 1 ,
315
348
strides = 1 ,
316
349
activation = activation ,
@@ -324,6 +357,7 @@ def apply(x):
324
357
325
358
def apply_cross_stage_partial (
326
359
filters ,
360
+ channel_axis ,
327
361
num_bottlenecks ,
328
362
residual = True ,
329
363
block_type = "basic_block" ,
@@ -361,6 +395,7 @@ def apply(inputs):
361
395
362
396
x1 = apply_darknet_conv_block (
363
397
hidden_channels ,
398
+ channel_axis ,
364
399
kernel_size = 1 ,
365
400
strides = 1 ,
366
401
activation = activation ,
@@ -369,6 +404,7 @@ def apply(inputs):
369
404
370
405
x2 = apply_darknet_conv_block (
371
406
hidden_channels ,
407
+ channel_axis ,
372
408
kernel_size = 1 ,
373
409
strides = 1 ,
374
410
activation = activation ,
@@ -379,13 +415,15 @@ def apply(inputs):
379
415
residual_x = x1
380
416
x1 = apply_darknet_conv_block (
381
417
hidden_channels ,
418
+ channel_axis ,
382
419
kernel_size = 1 ,
383
420
strides = 1 ,
384
421
activation = activation ,
385
422
name = f"{ name } _bottleneck_{ i } _conv1" ,
386
423
)(x1 )
387
424
x1 = ConvBlock (
388
425
hidden_channels ,
426
+ channel_axis ,
389
427
kernel_size = 3 ,
390
428
strides = 1 ,
391
429
activation = activation ,
@@ -399,6 +437,7 @@ def apply(inputs):
399
437
x = layers .Concatenate (name = f"{ name } _concat" )([x1 , x2 ])
400
438
x = apply_darknet_conv_block (
401
439
filters ,
440
+ channel_axis ,
402
441
kernel_size = 1 ,
403
442
strides = 1 ,
404
443
activation = activation ,
0 commit comments