Skip to content

Commit 8606edf

Browse files
sachinprasadhsmattdangerw
authored andcommitted
Add pyramid output for densenet, cspDarknet (#1801)
* add pyramid outputs * fix testcase * format fix * make common testcase for pyramid outputs * change default shape * simplify testcase * test case change and add channel axis
1 parent b9bc61e commit 8606edf

8 files changed

+152
-88
lines changed

keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py

+54-15
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
from keras import layers
1616

1717
from keras_nlp.src.api_export import keras_nlp_export
18-
from keras_nlp.src.models.backbone import Backbone
18+
from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
1919

2020

2121
@keras_nlp_export("keras_nlp.models.CSPDarkNetBackbone")
22-
class CSPDarkNetBackbone(Backbone):
22+
class CSPDarkNetBackbone(FeaturePyramidBackbone):
2323
"""This class represents Keras Backbone of CSPDarkNet model.
2424
2525
This class implements a CSPDarkNet backbone as described in
@@ -65,12 +65,15 @@ def __init__(
6565
self,
6666
stackwise_num_filters,
6767
stackwise_depth,
68-
include_rescaling,
68+
include_rescaling=True,
6969
block_type="basic_block",
70-
image_shape=(224, 224, 3),
70+
image_shape=(None, None, 3),
7171
**kwargs,
7272
):
7373
# === Functional Model ===
74+
channel_axis = (
75+
-1 if keras.config.image_data_format() == "channels_last" else 1
76+
)
7477
apply_ConvBlock = (
7578
apply_darknet_conv_block_depthwise
7679
if block_type == "depthwise_block"
@@ -83,15 +86,22 @@ def __init__(
8386
if include_rescaling:
8487
x = layers.Rescaling(scale=1 / 255.0)(x)
8588

86-
x = apply_focus(name="stem_focus")(x)
89+
x = apply_focus(channel_axis, name="stem_focus")(x)
8790
x = apply_darknet_conv_block(
88-
base_channels, kernel_size=3, strides=1, name="stem_conv"
91+
base_channels,
92+
channel_axis,
93+
kernel_size=3,
94+
strides=1,
95+
name="stem_conv",
8996
)(x)
97+
98+
pyramid_outputs = {}
9099
for index, (channels, depth) in enumerate(
91100
zip(stackwise_num_filters, stackwise_depth)
92101
):
93102
x = apply_ConvBlock(
94103
channels,
104+
channel_axis,
95105
kernel_size=3,
96106
strides=2,
97107
name=f"dark{index + 2}_conv",
@@ -100,17 +110,20 @@ def __init__(
100110
if index == len(stackwise_depth) - 1:
101111
x = apply_spatial_pyramid_pooling_bottleneck(
102112
channels,
113+
channel_axis,
103114
hidden_filters=channels // 2,
104115
name=f"dark{index + 2}_spp",
105116
)(x)
106117

107118
x = apply_cross_stage_partial(
108119
channels,
120+
channel_axis,
109121
num_bottlenecks=depth,
110122
block_type="basic_block",
111123
residual=(index != len(stackwise_depth) - 1),
112124
name=f"dark{index + 2}_csp",
113125
)(x)
126+
pyramid_outputs[f"P{index + 2}"] = x
114127

115128
super().__init__(inputs=image_input, outputs=x, **kwargs)
116129

@@ -120,6 +133,7 @@ def __init__(
120133
self.include_rescaling = include_rescaling
121134
self.block_type = block_type
122135
self.image_shape = image_shape
136+
self.pyramid_outputs = pyramid_outputs
123137

124138
def get_config(self):
125139
config = super().get_config()
@@ -135,7 +149,7 @@ def get_config(self):
135149
return config
136150

137151

138-
def apply_focus(name=None):
152+
def apply_focus(channel_axis, name=None):
139153
"""A block used in CSPDarknet to focus information into channels of the
140154
image.
141155
@@ -151,7 +165,7 @@ def apply_focus(name=None):
151165
"""
152166

153167
def apply(x):
154-
return layers.Concatenate(name=name)(
168+
return layers.Concatenate(axis=channel_axis, name=name)(
155169
[
156170
x[..., ::2, ::2, :],
157171
x[..., 1::2, ::2, :],
@@ -164,7 +178,13 @@ def apply(x):
164178

165179

166180
def apply_darknet_conv_block(
167-
filters, kernel_size, strides, use_bias=False, activation="silu", name=None
181+
filters,
182+
channel_axis,
183+
kernel_size,
184+
strides,
185+
use_bias=False,
186+
activation="silu",
187+
name=None,
168188
):
169189
"""
170190
The basic conv block used in Darknet. Applies Conv2D followed by a
@@ -193,11 +213,12 @@ def apply(inputs):
193213
kernel_size,
194214
strides,
195215
padding="same",
216+
data_format=keras.config.image_data_format(),
196217
use_bias=use_bias,
197218
name=name + "_conv",
198219
)(inputs)
199220

200-
x = layers.BatchNormalization(name=name + "_bn")(x)
221+
x = layers.BatchNormalization(axis=channel_axis, name=name + "_bn")(x)
201222

202223
if activation == "silu":
203224
x = layers.Lambda(lambda x: keras.activations.silu(x))(x)
@@ -212,7 +233,7 @@ def apply(inputs):
212233

213234

214235
def apply_darknet_conv_block_depthwise(
215-
filters, kernel_size, strides, activation="silu", name=None
236+
filters, channel_axis, kernel_size, strides, activation="silu", name=None
216237
):
217238
"""
218239
The depthwise conv block used in CSPDarknet.
@@ -236,9 +257,13 @@ def apply_darknet_conv_block_depthwise(
236257

237258
def apply(inputs):
238259
x = layers.DepthwiseConv2D(
239-
kernel_size, strides, padding="same", use_bias=False
260+
kernel_size,
261+
strides,
262+
padding="same",
263+
data_format=keras.config.image_data_format(),
264+
use_bias=False,
240265
)(inputs)
241-
x = layers.BatchNormalization()(x)
266+
x = layers.BatchNormalization(axis=channel_axis)(x)
242267

243268
if activation == "silu":
244269
x = layers.Lambda(lambda x: keras.activations.swish(x))(x)
@@ -248,7 +273,11 @@ def apply(inputs):
248273
x = layers.LeakyReLU(0.1)(x)
249274

250275
x = apply_darknet_conv_block(
251-
filters, kernel_size=1, strides=1, activation=activation
276+
filters,
277+
channel_axis,
278+
kernel_size=1,
279+
strides=1,
280+
activation=activation,
252281
)(x)
253282

254283
return x
@@ -258,6 +287,7 @@ def apply(inputs):
258287

259288
def apply_spatial_pyramid_pooling_bottleneck(
260289
filters,
290+
channel_axis,
261291
hidden_filters=None,
262292
kernel_sizes=(5, 9, 13),
263293
activation="silu",
@@ -291,6 +321,7 @@ def apply_spatial_pyramid_pooling_bottleneck(
291321
def apply(x):
292322
x = apply_darknet_conv_block(
293323
hidden_filters,
324+
channel_axis,
294325
kernel_size=1,
295326
strides=1,
296327
activation=activation,
@@ -304,13 +335,15 @@ def apply(x):
304335
kernel_size,
305336
strides=1,
306337
padding="same",
338+
data_format=keras.config.image_data_format(),
307339
name=f"{name}_maxpool_{kernel_size}",
308340
)(x[0])
309341
)
310342

311-
x = layers.Concatenate(name=f"{name}_concat")(x)
343+
x = layers.Concatenate(axis=channel_axis, name=f"{name}_concat")(x)
312344
x = apply_darknet_conv_block(
313345
filters,
346+
channel_axis,
314347
kernel_size=1,
315348
strides=1,
316349
activation=activation,
@@ -324,6 +357,7 @@ def apply(x):
324357

325358
def apply_cross_stage_partial(
326359
filters,
360+
channel_axis,
327361
num_bottlenecks,
328362
residual=True,
329363
block_type="basic_block",
@@ -361,6 +395,7 @@ def apply(inputs):
361395

362396
x1 = apply_darknet_conv_block(
363397
hidden_channels,
398+
channel_axis,
364399
kernel_size=1,
365400
strides=1,
366401
activation=activation,
@@ -369,6 +404,7 @@ def apply(inputs):
369404

370405
x2 = apply_darknet_conv_block(
371406
hidden_channels,
407+
channel_axis,
372408
kernel_size=1,
373409
strides=1,
374410
activation=activation,
@@ -379,13 +415,15 @@ def apply(inputs):
379415
residual_x = x1
380416
x1 = apply_darknet_conv_block(
381417
hidden_channels,
418+
channel_axis,
382419
kernel_size=1,
383420
strides=1,
384421
activation=activation,
385422
name=f"{name}_bottleneck_{i}_conv1",
386423
)(x1)
387424
x1 = ConvBlock(
388425
hidden_channels,
426+
channel_axis,
389427
kernel_size=3,
390428
strides=1,
391429
activation=activation,
@@ -399,6 +437,7 @@ def apply(inputs):
399437
x = layers.Concatenate(name=f"{name}_concat")([x1, x2])
400438
x = apply_darknet_conv_block(
401439
filters,
440+
channel_axis,
402441
kernel_size=1,
403442
strides=1,
404443
activation=activation,

keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,26 @@
2424
class CSPDarkNetBackboneTest(TestCase):
2525
def setUp(self):
2626
self.init_kwargs = {
27-
"stackwise_num_filters": [32, 64, 128, 256],
27+
"stackwise_num_filters": [2, 4, 6, 8],
2828
"stackwise_depth": [1, 3, 3, 1],
29-
"include_rescaling": False,
3029
"block_type": "basic_block",
31-
"image_shape": (224, 224, 3),
30+
"image_shape": (32, 32, 3),
3231
}
33-
self.input_data = np.ones((2, 224, 224, 3), dtype="float32")
32+
self.input_size = 32
33+
self.input_data = np.ones(
34+
(2, self.input_size, self.input_size, 3), dtype="float32"
35+
)
3436

3537
def test_backbone_basics(self):
36-
self.run_backbone_test(
38+
self.run_vision_backbone_test(
3739
cls=CSPDarkNetBackbone,
3840
init_kwargs=self.init_kwargs,
3941
input_data=self.input_data,
40-
expected_output_shape=(2, 7, 7, 256),
42+
expected_output_shape=(2, 1, 1, 8),
43+
expected_pyramid_output_keys=["P2", "P3", "P4", "P5"],
44+
expected_pyramid_image_sizes=[(8, 8), (4, 4), (2, 2), (1, 1)],
4145
run_mixed_precision_check=False,
46+
run_data_format_check=False,
4247
)
4348

4449
@pytest.mark.large

0 commit comments

Comments
 (0)