checkpoint conversion added
ushareng committed Oct 7, 2024
1 parent 35217ba commit 2e1e9c0
Showing 9 changed files with 235 additions and 142 deletions.
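
The conversion path added here can be exercised by loading a timm MobileNet checkpoint through a Hugging Face preset, as the updated test does. A minimal usage sketch, assuming access to the `mobilenetv3_small_050.lamb_in1k` weights referenced in `convert_mobilenet_test.py`:

import keras
from keras_hub.models import ImageClassifier

# Loading a timm preset routes through the MobileNet converters touched in this commit.
model = ImageClassifier.from_preset("hf://timm/mobilenetv3_small_050.lamb_in1k")

# MobileNetV3-small checkpoints expect 224x224 RGB inputs; the head covers 1000 ImageNet classes.
outputs = model.predict(keras.ops.ones((1, 224, 224, 3)))
print(outputs.shape)  # (1, 1000)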
3 changes: 3 additions & 0 deletions keras_hub/api/layers/__init__.py
@@ -40,6 +40,9 @@
from keras_hub.src.models.densenet.densenet_image_converter import (
DenseNetImageConverter,
)
from keras_hub.src.models.mobilenet.mobilenet_image_converter import (
MobileNetImageConverter,
)
from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
PaliGemmaImageConverter,
)
3 changes: 3 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -210,6 +210,9 @@
from keras_hub.src.models.mobilenet.mobilenet_image_classifier import (
MobileNetImageClassifier,
)
from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import (
MobileNetImageClassifierPreprocessor,
)
from keras_hub.src.models.opt.opt_backbone import OPTBackbone
from keras_hub.src.models.opt.opt_causal_lm import OPTCausalLM
from keras_hub.src.models.opt.opt_causal_lm_preprocessor import (
5 changes: 4 additions & 1 deletion keras_hub/src/models/mobilenet/mobilenet_image_classifier.py
@@ -1,8 +1,11 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.image_classifier import ImageClassifier
from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone

from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import (
MobileNetImageClassifierPreprocessor,
)

@keras_hub_export("keras_hub.models.MobileNetImageClassifier")
class MobileNetImageClassifier(ImageClassifier):
backbone_cls = MobileNetBackbone
preprocessor_cls = MobileNetImageClassifierPreprocessor
14 changes: 14 additions & 0 deletions keras_hub/src/models/mobilenet/mobilenet_image_classifier_preprocessor.py
@@ -0,0 +1,14 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.image_classifier_preprocessor import (
ImageClassifierPreprocessor,
)
from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone
from keras_hub.src.models.mobilenet.mobilenet_image_converter import (
MobileNetImageConverter,
)


@keras_hub_export("keras_hub.models.MobileNetImageClassifierPreprocessor")
class MobileNetImageClassifierPreprocessor(ImageClassifierPreprocessor):
backbone_cls = MobileNetBackbone
image_converter_cls = MobileNetImageConverter
8 changes: 8 additions & 0 deletions keras_hub/src/models/mobilenet/mobilenet_image_converter.py
@@ -0,0 +1,8 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone


@keras_hub_export("keras_hub.layers.MobileNetImageConverter")
class MobileNetImageConverter(ImageConverter):
backbone_cls = MobileNetBackbone
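
Both new classes follow the KerasHub registration pattern: `backbone_cls = MobileNetBackbone` is what lets preset loading match a MobileNet checkpoint to this converter and preprocessor. A minimal sketch of wiring them together by hand; the `image_size` and `image_converter` keyword arguments are assumptions about the base-class constructors, not taken from this diff:

from keras_hub.layers import MobileNetImageConverter
from keras_hub.models import MobileNetImageClassifierPreprocessor

# Resize and rescale inputs before they reach the backbone (assumed kwargs).
converter = MobileNetImageConverter(image_size=(224, 224))
preprocessor = MobileNetImageClassifierPreprocessor(image_converter=converter)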
Empty file.
230 changes: 90 additions & 140 deletions keras_hub/src/utils/timm/convert_mobilenet.py
@@ -27,80 +27,76 @@ def convert_backbone_config(timm_config):
stackwise_num_filters = [[16, 16], [24, 24, 24], [24, 24], [48, 48, 48]]
stackwise_kernel_size = [[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]]
stackwise_num_strides = [[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]]
stackwise_se_ratio = (
[
[None, None],
[0.25, 0.25, 0.25],
[0.3, 0.3],
[0.3, 0.25, 0.25],
],
)
stackwise_activation = (
[
["relu6", "relu6"],
["hard_swish", "hard_swish", "hard_swish"],
["hard_swish", "hard_swish"],
["hard_swish", "hard_swish", "hard_swish"],
],
)
stackwise_se_ratio = [
[None, None],
[0.25, 0.25, 0.25],
[0.3, 0.3],
[0.3, 0.25, 0.25],
]
stackwise_activation = [
["relu", "relu"],
["hard_swish", "hard_swish", "hard_swish"],
["hard_swish", "hard_swish"],
["hard_swish", "hard_swish", "hard_swish"],
]
output_num_filters = 1024
input_num_filters = 16
depthwise_filters = 8
squeeze_and_excite = 0.5
last_layer_filter = 288

elif timm_architecture == "mobilenetv2_050":
stackwise_num_blocks = ([2, 3, 4, 3, 3, 1],)
stackwise_expansion = (
[
[48, 96],
[96, 96, 96],
[96, 192, 192, 192],
[192, 288, 288],
[288, 480, 480],
[480],
],
)
stackwise_num_filters = (
[
[16, 16],
[16, 16, 16],
[32, 32, 32, 32],
[48, 48, 48],
[80, 80, 80],
[160],
],
)
stackwise_kernel_size = (
[[3, 3], [3, 3, 3], [3, 3, 3, 3], [3, 3, 3], [3, 3, 3], [3]],
)
stackwise_num_strides = (
[[2, 1], [2, 1, 1], [2, 1, 1, 1], [1, 1, 1], [2, 1, 1], [1]],
)
stackwise_se_ratio = (
[
[None, None],
[None, None, None],
[None, None, None, None],
[None, None, None],
[None, None, None],
[None],
],
)
stackwise_activation = (
[
["relu6", "relu6"],
["relu6", "relu6", "relu6"],
["relu6", "relu6", "relu6", "relu6"],
["relu6", "relu6", "relu6"],
["relu6", "relu6", "relu6"],
["relu6"],
],
)
output_num_filters = 1280
input_num_filters = 16
depthwise_filters = 8
squeeze_and_excite = None
# elif timm_architecture == "mobilenetv2_050":
# stackwise_num_blocks = ([2, 3, 4, 3, 3, 1],)
# stackwise_expansion = (
# [
# [48, 96],
# [96, 96, 96],
# [96, 192, 192, 192],
# [192, 288, 288],
# [288, 480, 480],
# [480],
# ],
# )
# stackwise_num_filters = (
# [
# [16, 16],
# [16, 16, 16],
# [32, 32, 32, 32],
# [48, 48, 48],
# [80, 80, 80],
# [160],
# ],
# )
# stackwise_kernel_size = (
# [[3, 3], [3, 3, 3], [3, 3, 3, 3], [3, 3, 3], [3, 3, 3], [3]],
# )
# stackwise_num_strides = (
# [[2, 1], [2, 1, 1], [2, 1, 1, 1], [1, 1, 1], [2, 1, 1], [1]],
# )
# stackwise_se_ratio = (
# [
# [None, None],
# [None, None, None],
# [None, None, None, None],
# [None, None, None],
# [None, None, None],
# [None],
# ],
# )
# stackwise_activation = (
# [
# ["relu6", "relu6"],
# ["relu6", "relu6", "relu6"],
# ["relu6", "relu6", "relu6", "relu6"],
# ["relu6", "relu6", "relu6"],
# ["relu6", "relu6", "relu6"],
# ["relu6"],
# ],
# )
# output_num_filters = 1280
# input_num_filters = 16
# depthwise_filters = 8
# squeeze_and_excite = None
else:
raise ValueError(
f"Currently, the architecture {timm_architecture} is not supported."
@@ -158,7 +154,7 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix):

# DepthWise Block (block 0)
hf_name = "blocks.0.0"
keras_name = "blocks_0"
keras_name = "block_0_0"
port_conv2d(f"{keras_name}_conv1", f"{hf_name}.conv_dw")
port_batch_normalization(f"{keras_name}_bn1", f"{hf_name}.bn1")

@@ -172,86 +168,40 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix):
num_stacks = len(backbone.stackwise_num_blocks)
for block_idx in range(num_stacks):
for inverted_block in range(backbone.stackwise_num_blocks[block_idx]):
# if version == "v1":
# keras_name = f"stack{stack_index}_block{block_idx}"
# hf_name = f"layer{stack_index+1}.{block_idx}"
# else:
# keras_name = f"stack{stack_index}_block{block_idx}"
# hf_name = f"stages.{stack_index}.blocks.{block_idx}"
keras_name = f"block_{block_idx+1}_{inverted_block}"
hf_name = f"blocks.{block_idx+1}.{inverted_block}"

# ConvBnAct Block
if block_idx == num_stacks - 1 and version == "v3":
port_conv2d(f"{keras_name}_conv", f"{hf_name}.conv")
port_batch_normalization(f"{keras_name}_bn", f"{hf_name}.bn1")

# Inverted Residual Block
else:
port_conv2d(f"{keras_name}_conv1", f"{hf_name}.conv_pw")
port_batch_normalization(f"{keras_name}_bn1", f"{hf_name}.bn1")
port_conv2d(f"{keras_name}_conv2", f"{hf_name}.conv_dw")
port_batch_normalization(f"{keras_name}_bn2", f"{hf_name}.bn2")

if backbone.stackwise_se_ratio[block_idx][inverted_block]:
port_conv2d(
f"{keras_name}_se_conv_reduce",
f"{hf_name}.se.conv_reduce",
)
port_conv2d(
f"{keras_name}_se_conv_expand",
f"{hf_name}.se.conv_expand",
)

port_conv2d(f"{keras_name}_c onv3", f"{hf_name}.conv_pwl")
port_batch_normalization(f"{keras_name}_bn3", f"{hf_name}.bn3")
port_conv2d(f"{keras_name}_conv1", f"{hf_name}.conv_pw")
port_batch_normalization(f"{keras_name}_bn1", f"{hf_name}.bn1")
port_conv2d(f"{keras_name}_conv2", f"{hf_name}.conv_dw")
port_batch_normalization(f"{keras_name}_bn2", f"{hf_name}.bn2")

if backbone.stackwise_se_ratio[block_idx][inverted_block]:
port_conv2d(
f"{keras_name}_se_conv_reduce",
f"{hf_name}.se.conv_reduce",
)
port_conv2d(
f"{keras_name}_se_conv_expand",
f"{hf_name}.se.conv_expand",
)

port_conv2d(f"{keras_name}_conv3", f"{hf_name}.conv_pwl")
port_batch_normalization(f"{keras_name}_bn3", f"{hf_name}.bn3")

# ConvBnAct Block
port_conv2d(f"block_{num_stacks+1}_0_conv", f"blocks.{num_stacks+1}.0.conv")
port_batch_normalization(
f"block_{num_stacks+1}_0_bn", f"blocks.{num_stacks+1}.0.bn1"
)

if version == "v3":
hf_name = f"blocks.{num_stacks+1}.0"
keras_name = "Dfs"
port_conv2d("output_conv", "conv_head")
if version == "v2":
port_batch_normalization("output_batch_norm", "bn2")

# if version == "v1":
# if block_idx == 0 and (
# block_type == "bottleneck_block" or stack_index > 0
# ):
# port_conv2d(
# f"{keras_name}_0_conv", f"{hf_name}.downsample.0"
# )
# port_batch_normalization(
# f"{keras_name}_0_bn", f"{hf_name}.downsample.1"
# )
# port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1")
# port_batch_normalization(f"{keras_name}_1_bn", f"{hf_name}.bn1")
# port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2")
# port_batch_normalization(f"{keras_name}_2_bn", f"{hf_name}.bn2")
# if block_type == "bottleneck_block":
# port_conv2d(f"{keras_name}_3_conv", f"{hf_name}.conv3")
# port_batch_normalization(
# f"{keras_name}_3_bn", f"{hf_name}.bn3"
# )
# else:
# if block_idx == 0 and (
# block_type == "bottleneck_block" or stack_index > 0
# ):
# port_conv2d(
# f"{keras_name}_0_conv", f"{hf_name}.downsample.conv"
# )
# port_batch_normalization(
# f"{keras_name}_pre_activation_bn", f"{hf_name}.norm1"
# )
# port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1")
# port_batch_normalization(
# f"{keras_name}_1_bn", f"{hf_name}.norm2"
# )
# port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2")
# if block_type == "bottleneck_block":
# port_batch_normalization(
# f"{keras_name}_2_bn", f"{hf_name}.norm3"
# )
# port_conv2d(f"{keras_name}_3_conv", f"{hf_name}.conv3")
# if version == "v2":
# port_batch_normalization("output_batch_norm", "bn2")


def convert_head(task, loader, timm_config):
2 changes: 1 addition & 1 deletion keras_hub/src/utils/timm/convert_mobilenet_test.py
@@ -20,7 +20,7 @@ def test_convert_mobilenet_classifier(self):
model = ImageClassifier.from_preset(
"hf://timm/mobilenetv3_small_050.lamb_in1k"
)
outputs = model.predict(ops.ones((1, 512, 512, 3)))
outputs = model.predict(ops.ones((1, 224, 224, 3)))
self.assertEqual(outputs.shape, (1, 1000))

# TODO: compare numerics with timm model
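
One way to close the TODO above is to run the same input through the original timm model and compare outputs. A rough sketch, assuming `timm` and `torch` are installed in the test environment; it ignores preprocessing and output-activation differences, so the tolerance would need tuning:

import numpy as np
import timm
import torch

timm_model = timm.create_model("mobilenetv3_small_050.lamb_in1k", pretrained=True)
timm_model.eval()
x = np.ones((1, 224, 224, 3), dtype="float32")
with torch.no_grad():
    # timm models expect channels-first tensors.
    timm_outputs = timm_model(torch.from_numpy(x.transpose(0, 3, 1, 2))).numpy()
# `outputs` is the KerasHub prediction from the test above.
np.testing.assert_allclose(outputs, timm_outputs, atol=1e-3)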