Add pyramid output for densenet, cspDarknet #1801

Merged 7 commits on Sep 3, 2024
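This PR turns `CSPDarkNetBackbone` into a `FeaturePyramidBackbone`: each downsampling stage now registers its output in a `pyramid_outputs` dict keyed "P2" through "P5", and the convolution, normalization, and concatenation layers take an explicit channel axis and data format derived from `keras.config.image_data_format()`. A minimal usage sketch of the new attribute follows; the constructor arguments mirror the signature in the diff below, while the feature-extractor wrapper and the 64x64 input are illustrative assumptions rather than code from this PR.

import keras
import numpy as np
from keras_nlp.models import CSPDarkNetBackbone  # exported name per the @keras_nlp_export decorator below

# Constructor arguments follow the backbone signature in this PR.
backbone = CSPDarkNetBackbone(
    stackwise_num_filters=[32, 64, 128, 256],
    stackwise_depth=[1, 3, 3, 1],
    block_type="basic_block",
    image_shape=(None, None, 3),
)

# `pyramid_outputs` maps "P2".."P5" to symbolic stage outputs; wrapping them in
# a functional model yields a multi-scale feature extractor.
extractor = keras.Model(backbone.inputs, backbone.pyramid_outputs)
features = extractor(np.ones((1, 64, 64, 3), dtype="float32"))
for level, feature_map in features.items():
    print(level, feature_map.shape)  # P2 (1, 16, 16, 32) ... P5 (1, 2, 2, 256)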
69 changes: 54 additions & 15 deletions keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py
@@ -15,11 +15,11 @@
from keras import layers

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.backbone import Backbone
from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone


@keras_nlp_export("keras_nlp.models.CSPDarkNetBackbone")
class CSPDarkNetBackbone(Backbone):
class CSPDarkNetBackbone(FeaturePyramidBackbone):
"""This class represents Keras Backbone of CSPDarkNet model.

This class implements a CSPDarkNet backbone as described in
@@ -65,12 +65,15 @@ def __init__(
self,
stackwise_num_filters,
stackwise_depth,
include_rescaling,
include_rescaling=True,
block_type="basic_block",
image_shape=(224, 224, 3),
image_shape=(None, None, 3),
**kwargs,
):
# === Functional Model ===
channel_axis = (
-1 if keras.config.image_data_format() == "channels_last" else 1
)
apply_ConvBlock = (
apply_darknet_conv_block_depthwise
if block_type == "depthwise_block"
@@ -83,15 +86,22 @@
if include_rescaling:
x = layers.Rescaling(scale=1 / 255.0)(x)

x = apply_focus(name="stem_focus")(x)
x = apply_focus(channel_axis, name="stem_focus")(x)
x = apply_darknet_conv_block(
base_channels, kernel_size=3, strides=1, name="stem_conv"
base_channels,
channel_axis,
kernel_size=3,
strides=1,
name="stem_conv",
)(x)

pyramid_outputs = {}
for index, (channels, depth) in enumerate(
zip(stackwise_num_filters, stackwise_depth)
):
x = apply_ConvBlock(
channels,
channel_axis,
kernel_size=3,
strides=2,
name=f"dark{index + 2}_conv",
@@ -100,17 +110,20 @@
if index == len(stackwise_depth) - 1:
x = apply_spatial_pyramid_pooling_bottleneck(
channels,
channel_axis,
hidden_filters=channels // 2,
name=f"dark{index + 2}_spp",
)(x)

x = apply_cross_stage_partial(
channels,
channel_axis,
num_bottlenecks=depth,
block_type="basic_block",
residual=(index != len(stackwise_depth) - 1),
name=f"dark{index + 2}_csp",
)(x)
pyramid_outputs[f"P{index + 2}"] = x

super().__init__(inputs=image_input, outputs=x, **kwargs)

@@ -120,6 +133,7 @@ def __init__(
self.include_rescaling = include_rescaling
self.block_type = block_type
self.image_shape = image_shape
self.pyramid_outputs = pyramid_outputs

def get_config(self):
config = super().get_config()
@@ -135,7 +149,7 @@ def get_config(self):
return config


def apply_focus(name=None):
def apply_focus(channel_axis, name=None):
"""A block used in CSPDarknet to focus information into channels of the
image.

@@ -151,7 +165,7 @@
"""

def apply(x):
return layers.Concatenate(name=name)(
return layers.Concatenate(axis=channel_axis, name=name)(
[
x[..., ::2, ::2, :],
x[..., 1::2, ::2, :],
@@ -164,7 +178,13 @@


def apply_darknet_conv_block(
filters, kernel_size, strides, use_bias=False, activation="silu", name=None
filters,
channel_axis,
kernel_size,
strides,
use_bias=False,
activation="silu",
name=None,
):
"""
The basic conv block used in Darknet. Applies Conv2D followed by a
@@ -193,11 +213,12 @@ def apply(inputs):
kernel_size,
strides,
padding="same",
data_format=keras.config.image_data_format(),
use_bias=use_bias,
name=name + "_conv",
)(inputs)

x = layers.BatchNormalization(name=name + "_bn")(x)
x = layers.BatchNormalization(axis=channel_axis, name=name + "_bn")(x)

if activation == "silu":
x = layers.Lambda(lambda x: keras.activations.silu(x))(x)
@@ -212,7 +233,7 @@ def apply(inputs):


def apply_darknet_conv_block_depthwise(
filters, kernel_size, strides, activation="silu", name=None
filters, channel_axis, kernel_size, strides, activation="silu", name=None
):
"""
The depthwise conv block used in CSPDarknet.
@@ -236,9 +257,13 @@ def apply_darknet_conv_block_depthwise(

def apply(inputs):
x = layers.DepthwiseConv2D(
kernel_size, strides, padding="same", use_bias=False
kernel_size,
strides,
padding="same",
data_format=keras.config.image_data_format(),
use_bias=False,
)(inputs)
x = layers.BatchNormalization()(x)
x = layers.BatchNormalization(axis=channel_axis)(x)

if activation == "silu":
x = layers.Lambda(lambda x: keras.activations.swish(x))(x)
Expand All @@ -248,7 +273,11 @@ def apply(inputs):
x = layers.LeakyReLU(0.1)(x)

x = apply_darknet_conv_block(
filters, kernel_size=1, strides=1, activation=activation
filters,
channel_axis,
kernel_size=1,
strides=1,
activation=activation,
)(x)

return x
@@ -258,6 +287,7 @@ def apply(inputs):

def apply_spatial_pyramid_pooling_bottleneck(
filters,
channel_axis,
hidden_filters=None,
kernel_sizes=(5, 9, 13),
activation="silu",
@@ -291,6 +321,7 @@ def apply_spatial_pyramid_pooling_bottleneck(
def apply(x):
x = apply_darknet_conv_block(
hidden_filters,
channel_axis,
kernel_size=1,
strides=1,
activation=activation,
@@ -304,13 +335,15 @@ def apply(x):
kernel_size,
strides=1,
padding="same",
data_format=keras.config.image_data_format(),
name=f"{name}_maxpool_{kernel_size}",
)(x[0])
)

x = layers.Concatenate(name=f"{name}_concat")(x)
x = layers.Concatenate(axis=channel_axis, name=f"{name}_concat")(x)
x = apply_darknet_conv_block(
filters,
channel_axis,
kernel_size=1,
strides=1,
activation=activation,
@@ -324,6 +357,7 @@

def apply_cross_stage_partial(
filters,
channel_axis,
num_bottlenecks,
residual=True,
block_type="basic_block",
@@ -361,6 +395,7 @@ def apply(inputs):

x1 = apply_darknet_conv_block(
hidden_channels,
channel_axis,
kernel_size=1,
strides=1,
activation=activation,
@@ -369,6 +404,7 @@ def apply(inputs):

x2 = apply_darknet_conv_block(
hidden_channels,
channel_axis,
kernel_size=1,
strides=1,
activation=activation,
@@ -379,13 +415,15 @@ def apply(inputs):
residual_x = x1
x1 = apply_darknet_conv_block(
hidden_channels,
channel_axis,
kernel_size=1,
strides=1,
activation=activation,
name=f"{name}_bottleneck_{i}_conv1",
)(x1)
x1 = ConvBlock(
hidden_channels,
channel_axis,
kernel_size=3,
strides=1,
activation=activation,
@@ -399,6 +437,7 @@ def apply(inputs):
x = layers.Concatenate(name=f"{name}_concat")([x1, x2])
x = apply_darknet_conv_block(
filters,
channel_axis,
kernel_size=1,
strides=1,
activation=activation,
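Aside from the pyramid outputs, the diff above threads a `channel_axis` argument and an explicit `data_format` through every Conv2D, DepthwiseConv2D, MaxPooling2D, BatchNormalization, and Concatenate call, so the blocks follow the globally configured image layout instead of assuming channels-last. The axis selection itself is the one-liner from the constructor; the layout comments are just a reminder of what each setting implies.

import keras

# Derive the channel axis from the global image data format (as in the diff):
channel_axis = (
    -1 if keras.config.image_data_format() == "channels_last" else 1
)
# channels_last:  tensors are (batch, height, width, channels) -> axis -1
# channels_first: tensors are (batch, channels, height, width) -> axis 1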
17 changes: 11 additions & 6 deletions keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py
@@ -24,21 +24,26 @@
class CSPDarkNetBackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
"stackwise_num_filters": [32, 64, 128, 256],
"stackwise_num_filters": [2, 4, 6, 8],
"stackwise_depth": [1, 3, 3, 1],
"include_rescaling": False,
"block_type": "basic_block",
"image_shape": (224, 224, 3),
"image_shape": (32, 32, 3),
}
self.input_data = np.ones((2, 224, 224, 3), dtype="float32")
self.input_size = 32
self.input_data = np.ones(
(2, self.input_size, self.input_size, 3), dtype="float32"
)

def test_backbone_basics(self):
self.run_backbone_test(
self.run_vision_backbone_test(
cls=CSPDarkNetBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 7, 7, 256),
expected_output_shape=(2, 1, 1, 8),
expected_pyramid_output_keys=["P2", "P3", "P4", "P5"],
expected_pyramid_image_sizes=[(8, 8), (4, 4), (2, 2), (1, 1)],
run_mixed_precision_check=False,
run_data_format_check=False,
)

@pytest.mark.large
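For reference, the pyramid sizes asserted in the test above follow directly from the backbone's downsampling schedule: the stem focus halves the spatial resolution once, and each of the four stages halves it again, so level P{i} carries an overall stride of 2**i. A quick arithmetic check against the 32x32 test input:

# Input size and expected sizes match the test's setUp / expected_pyramid_image_sizes.
input_size = 32
for i in range(2, 6):  # pyramid levels P2..P5
    print(f"P{i}", input_size // 2**i)
# P2 8, P3 4, P4 2, P5 1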