Commit

refactoring
ushareng committed Oct 10, 2024
1 parent 8032c8c commit 615494d
Showing 7 changed files with 63 additions and 87 deletions.
6 changes: 3 additions & 3 deletions keras_hub/api/layers/__init__.py
@@ -40,12 +40,12 @@
from keras_hub.src.models.densenet.densenet_image_converter import (
DenseNetImageConverter,
)
from keras_hub.src.models.mobilenet.mobilenet_image_converter import (
MobileNetImageConverter,
)
from keras_hub.src.models.mix_transformer.mix_transformer_image_converter import (
MiTImageConverter,
)
from keras_hub.src.models.mobilenet.mobilenet_image_converter import (
MobileNetImageConverter,
)
from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
PaliGemmaImageConverter,
)
59 changes: 41 additions & 18 deletions keras_hub/src/models/mobilenet/mobilenet_backbone.py
@@ -115,6 +115,7 @@ def __init__(
stackwise_num_strides,
stackwise_se_ratio,
stackwise_activation,
stackwise_padding,
output_num_filters,
depthwise_filters,
last_layer_filter,
@@ -133,11 +134,18 @@ def __init__(
image_input = keras.layers.Input(shape=image_shape)
x = image_input
input_num_filters = adjust_channels(input_num_filters)

pad_width = (
(0, 0),  # No padding for batch
(1, 1),  # 1 pixel padding for height
(1, 1),  # 1 pixel padding for width
(0, 0),  # No padding for channels
)
x = ops.pad(x, pad_width=pad_width)
x = keras.layers.Conv2D(
input_num_filters,
kernel_size=3,
strides=(2, 2),
padding="same",
data_format=keras.config.image_data_format(),
use_bias=False,
name="input_conv",
@@ -166,6 +174,7 @@ def __init__(
stride=stackwise_num_strides[block][inverted_block],
se_ratio=stackwise_se_ratio[block][inverted_block],
activation=stackwise_activation[block][inverted_block],
padding=stackwise_padding[block][inverted_block],
name=f"block_{block+1}_{inverted_block}",
)

@@ -181,7 +190,6 @@ def __init__(
x = keras.layers.Conv2D(
last_conv_ch,
kernel_size=1,
padding="same",
data_format=keras.config.image_data_format(),
use_bias=False,
name="output_conv",
@@ -208,6 +216,7 @@ def __init__(
self.stackwise_num_strides = stackwise_num_strides
self.stackwise_se_ratio = stackwise_se_ratio
self.stackwise_activation = stackwise_activation
self.stackwise_padding = stackwise_padding
self.input_num_filters = input_num_filters
self.output_num_filters = output_num_filters
self.depthwise_filters = depthwise_filters
@@ -228,6 +237,7 @@ def get_config(self):
"stackwise_num_strides": self.stackwise_num_strides,
"stackwise_se_ratio": self.stackwise_se_ratio,
"stackwise_activation": self.stackwise_activation,
"stackwise_padding": self.stackwise_padding,
"image_shape": self.image_shape,
"input_num_filters": self.input_num_filters,
"output_num_filters": self.output_num_filters,
@@ -278,6 +288,7 @@ def apply_inverted_res_block(
stride,
se_ratio,
activation,
padding,
name=None,
):
"""An Inverted Residual Block.
@@ -292,6 +303,7 @@ def apply_inverted_res_block(
se_ratio: float, ratio for bottleneck filters. Number of bottleneck
filters = filters * se_ratio.
activation: the activation layer to use.
padding: int, amount of spatial padding applied before the depthwise conv2d layer.
name: string, block label.
Returns:
@@ -308,7 +320,6 @@ def apply_inverted_res_block(
x = keras.layers.Conv2D(
expanded_channels,
kernel_size=1,
padding="same",
data_format=keras.config.image_data_format(),
use_bias=False,
name=f"{name}_conv1",
@@ -323,16 +334,25 @@ def apply_inverted_res_block(

x = keras.layers.Activation(activation=activation)(x)

if stride == 2:
x = keras.layers.ZeroPadding2D(
padding=correct_pad_downsample(x, kernel_size),
)(x)
# if stride == 2:
# x = keras.layers.ZeroPadding2D(
# padding=correct_pad_downsample(x, kernel_size),
# )(x)

pad_width = (
(0, 0),  # No padding for batch
(padding, padding),  # `padding` pixels for height
(padding, padding),  # `padding` pixels for width
(0, 0),  # No padding for channels
)
x = ops.pad(x, pad_width=pad_width)

x = keras.layers.Conv2D(
expanded_channels,
kernel_size,
strides=stride,
padding="same" if stride == 1 else "valid",
padding="valid",
groups=expanded_channels,
data_format=keras.config.image_data_format(),
use_bias=False,
@@ -361,7 +381,6 @@ def apply_inverted_res_block(
x = keras.layers.Conv2D(
filters,
kernel_size=1,
padding="same",
data_format=keras.config.image_data_format(),
use_bias=False,
name=f"{name}_conv3",
@@ -379,7 +398,7 @@ def apply_inverted_res_block(


def apply_depthwise_conv_block(
x, filters, kernel_size=3, stride=1, se=None, name=None
x, filters, kernel_size=3, stride=2, se=None, name=None
):
"""Adds a depthwise convolution block.
@@ -410,16 +429,22 @@ def apply_depthwise_conv_block(
infilters = x.shape[channel_axis]
name = f"{name}_0"

if stride == 2:
x = keras.layers.ZeroPadding2D(
padding=correct_pad_downsample(x, kernel_size),
)(x)

# if stride == 2:
# x = keras.layers.ZeroPadding2D(
# padding=correct_pad_downsample(x, kernel_size),
# )(x)
pad_width = (
(0, 0),  # No padding for batch
(1, 1),  # 1 pixel padding for height
(1, 1),  # 1 pixel padding for width
(0, 0),  # No padding for channels
)
x = ops.pad(x, pad_width=pad_width)
x = keras.layers.Conv2D(
infilters,
kernel_size,
strides=stride,
padding="same" if stride == 1 else "valid",
padding="valid",
data_format=keras.config.image_data_format(),
groups=infilters,
use_bias=False,
@@ -446,7 +471,6 @@ def apply_depthwise_conv_block(
x = keras.layers.Conv2D(
filters,
kernel_size=1,
padding="same",
data_format=keras.config.image_data_format(),
use_bias=False,
name=f"{name}_conv2",
@@ -520,7 +544,6 @@ def ConvBnAct(x, filter, activation, name=None):
x = keras.layers.Conv2D(
filter,
kernel_size=1,
padding="same",
data_format=keras.config.image_data_format(),
use_bias=False,
name=f"{name}_conv",
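Note on the padding refactor above: the backbone now pads inputs explicitly with ops.pad and runs the strided convolutions with padding="valid" instead of relying on padding="same". Below is a minimal sketch, not part of this commit, using made-up channel counts: symmetric (1, 1) spatial padding followed by a "valid" stride-2 conv yields the same output size as a "same" stride-2 conv, while padding both sides of each axis the way PyTorch/timm convolutions do (Keras "same" typically pads asymmetrically when the input size is even).

# Sketch only (assumptions: channels_last data format, toy shapes and filter counts).
import keras
from keras import ops

x = ops.ones((2, 224, 224, 3))  # hypothetical NHWC batch

same_conv = keras.layers.Conv2D(8, 3, strides=2, padding="same", use_bias=False)
valid_conv = keras.layers.Conv2D(8, 3, strides=2, padding="valid", use_bias=False)

# Explicit symmetric padding, mirroring the pad_width tuples added in this file.
pad_width = ((0, 0), (1, 1), (1, 1), (0, 0))  # batch, height, width, channels

y_same = same_conv(x)
y_valid = valid_conv(ops.pad(x, pad_width=pad_width))
print(y_same.shape, y_valid.shape)  # both (2, 112, 112, 8)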
7 changes: 4 additions & 3 deletions keras_hub/src/models/mobilenet/mobilenet_backbone_test.py
@@ -37,8 +37,8 @@ def setUp(self):
"stackwise_se_ratio": [
[None, None],
[0.25, 0.25, 0.25],
[0.3, 0.3],
[0.3, 0.25, 0.25],
[0.25, 0.25],
[0.25, 0.25, 0.25],
],
"stackwise_activation": [
["relu", "relu"],
@@ -47,6 +47,7 @@
["hard_swish", "hard_swish", "hard_swish"],
["hard_swish"],
],
"stackwise_padding": [[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]],
"output_num_filters": 1024,
"input_activation": "hard_swish",
"output_activation": "hard_swish",
@@ -63,7 +64,7 @@ def test_backbone_basics(self):
cls=MobileNetBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 14, 14, 1024),
expected_output_shape=(2, 7, 7, 1024),
run_mixed_precision_check=False,
run_data_format_check=False,
)
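The new stackwise_padding argument follows the same nested-list convention as the other stackwise_* arguments: one inner list per stack, one integer per inverted-residual block, consumed as stackwise_padding[block][inverted_block]. A small illustration (not from the repository) of how the entries map onto the block names the backbone generates:

# Illustration only: nested stackwise_padding values vs. generated block names.
stackwise_padding = [[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]]
for block, paddings in enumerate(stackwise_padding):
    for inverted_block, padding in enumerate(paddings):
        print(f"block_{block + 1}_{inverted_block}: padding={padding}")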
@@ -5,6 +5,7 @@
MobileNetImageClassifierPreprocessor,
)


@keras_hub_export("keras_hub.models.MobileNetImageClassifier")
class MobileNetImageClassifier(ImageClassifier):
backbone_cls = MobileNetBackbone
@@ -32,15 +32,16 @@ def setUp(self):
stackwise_se_ratio=[
[None, None],
[0.25, 0.25, 0.25],
[0.3, 0.3],
[0.3, 0.25, 0.25],
[0.25, 0.25],
[0.25, 0.25, 0.25],
],
stackwise_activation=[
["relu", "relu"],
["hard_swish", "hard_swish", "hard_swish"],
["hard_swish", "hard_swish"],
["hard_swish", "hard_swish", "hard_swish"],
],
stackwise_padding=[[1, 1], [2, 2, 2], [2, 2], [2, 2, 2], [1]],
output_num_filters=1024,
input_activation="hard_swish",
output_activation="hard_swish",
70 changes: 10 additions & 60 deletions keras_hub/src/utils/timm/convert_mobilenet.py
@@ -30,73 +30,21 @@ def convert_backbone_config(timm_config):
stackwise_se_ratio = [
[None, None],
[0.25, 0.25, 0.25],
[0.3, 0.3],
[0.3, 0.25, 0.25],
[0.25, 0.25],
[0.25, 0.25, 0.25],
]
stackwise_activation = [
["relu", "relu"],
["hard_swish", "hard_swish", "hard_swish"],
["hard_swish", "hard_swish"],
["hard_swish", "hard_swish", "hard_swish"],
]
stackwise_padding = [[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]]
output_num_filters = 1024
input_num_filters = 16
depthwise_filters = 8
squeeze_and_excite = 0.5
last_layer_filter = 288

# elif timm_architecture == "mobilenetv2_050":
# stackwise_num_blocks = ([2, 3, 4, 3, 3, 1],)
# stackwise_expansion = (
# [
# [48, 96],
# [96, 96, 96],
# [96, 192, 192, 192],
# [192, 288, 288],
# [288, 480, 480],
# [480],
# ],
# )
# stackwise_num_filters = (
# [
# [16, 16],
# [16, 16, 16],
# [32, 32, 32, 32],
# [48, 48, 48],
# [80, 80, 80],
# [160],
# ],
# )
# stackwise_kernel_size = (
# [[3, 3], [3, 3, 3], [3, 3, 3, 3], [3, 3, 3], [3, 3, 3], [3]],
# )
# stackwise_num_strides = (
# [[2, 1], [2, 1, 1], [2, 1, 1, 1], [1, 1, 1], [2, 1, 1], [1]],
# )
# stackwise_se_ratio = (
# [
# [None, None],
# [None, None, None],
# [None, None, None, None],
# [None, None, None],
# [None, None, None],
# [None],
# ],
# )
# stackwise_activation = (
# [
# ["relu6", "relu6"],
# ["relu6", "relu6", "relu6"],
# ["relu6", "relu6", "relu6", "relu6"],
# ["relu6", "relu6", "relu6"],
# ["relu6", "relu6", "relu6"],
# ["relu6"],
# ],
# )
# output_num_filters = 1280
# input_num_filters = 16
# depthwise_filters = 8
# squeeze_and_excite = None
else:
raise ValueError(
f"Currently, the architecture {timm_architecture} is not supported."
@@ -114,6 +62,7 @@ def convert_backbone_config(timm_config):
stackwise_num_strides=stackwise_num_strides,
stackwise_se_ratio=stackwise_se_ratio,
stackwise_activation=stackwise_activation,
stackwise_padding=stackwise_padding,
output_num_filters=output_num_filters,
output_activation=output_activation,
last_layer_filter=last_layer_filter,
@@ -122,13 +71,15 @@

def convert_weights(backbone, loader, timm_config):
def port_conv2d(keras_layer_name, hf_weight_prefix):
print(f"porting weights {hf_weight_prefix} -> {keras_layer_name}")
loader.port_weight(
backbone.get_layer(keras_layer_name).kernel,
hf_weight_key=f"{hf_weight_prefix}.weight",
hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
)

def port_batch_normalization(keras_layer_name, hf_weight_prefix):
print(f"porting weights {hf_weight_prefix} -> {keras_layer_name}")
loader.port_weight(
backbone.get_layer(keras_layer_name).gamma,
hf_weight_key=f"{hf_weight_prefix}.weight",
@@ -145,8 +96,10 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix):
backbone.get_layer(keras_layer_name).moving_variance,
hf_weight_key=f"{hf_weight_prefix}.running_var",
)

version = "v3" if backbone.output_activation == "hard_swish" else "v2"
loader.port_weight(
backbone.get_layer(keras_layer_name).moving_variance,
hf_weight_key=f"{hf_weight_prefix}.running_var",
)

# Stem
port_conv2d("input_conv", "conv_stem")
@@ -196,9 +149,6 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix):
f"block_{num_stacks+1}_0_bn", f"blocks.{num_stacks+1}.0.bn1"
)

if version == "v3":
hf_name = f"blocks.{num_stacks+1}.0"
keras_name = "Dfs"
port_conv2d("output_conv", "conv_head")
# if version == "v2":
# port_batch_normalization("output_batch_norm", "bn2")
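For context, the hook_fn in port_conv2d applies np.transpose(x, (2, 3, 1, 0)) because timm/PyTorch stores conv kernels as (out_channels, in_channels, kernel_h, kernel_w) while Keras Conv2D expects (kernel_h, kernel_w, in_channels, out_channels). A minimal sketch with a made-up kernel shape:

# Sketch only: PyTorch-style OIHW kernel reshaped to the Keras HWIO layout.
import numpy as np

torch_kernel = np.zeros((16, 3, 3, 3))  # hypothetical (out, in, kh, kw) kernel
keras_kernel = np.transpose(torch_kernel, (2, 3, 1, 0))
print(keras_kernel.shape)  # (3, 3, 3, 16), i.e. (kh, kw, in, out)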
2 changes: 1 addition & 1 deletion keras_hub/src/utils/timm/convert_mobilenet_test.py
@@ -13,7 +13,7 @@ def test_convert_mobilenet_backbone(self):
"hf://timm/mobilenetv3_small_050.lamb_in1k"
)
outputs = model.predict(ops.ones((1, 224, 224, 3)))
self.assertEqual(outputs.shape, (1, 14, 14, 1024))
self.assertEqual(outputs.shape, (1, 7, 7, 1024))

@pytest.mark.large
def test_convert_mobilenet_classifier(self):
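The expected backbone output drops from (1, 14, 14, 1024) to (1, 7, 7, 1024), i.e. one extra stride-2 downsampling. This is consistent with the apply_depthwise_conv_block default stride changing from 1 to 2 earlier in this commit, and with MobileNetV3-Small's usual total stride of 32. A quick sanity check, assuming the standard layout of one stride-2 stem conv plus four more stride-2 stages:

# Sanity check only (assumption: five stride-2 stages in total).
input_size = 224
total_stride = 2 ** 5
print(input_size // total_stride)  # 7, matching the updated expected shape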
