Update VGG model to be compatible with Timm weights and add conversion scripts #1914

Merged 1 commit on Oct 10, 2024
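
For orientation, a minimal sketch of what this change enables, based on the `from_preset` call in `tools/checkpoint_conversion/convert_vgg_checkpoints.py` below (the `hf://timm/...` handle is one of the script's `PRESET_MAP` entries):

import keras_hub

# Build a KerasHub VGG classifier directly from a timm checkpoint; the
# converter added in keras_hub/src/utils/timm/convert_vgg.py handles the
# weight mapping.
classifier = keras_hub.models.ImageClassifier.from_preset(
    "hf://timm/vgg16.tv_in1k",
)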
1 change: 1 addition & 0 deletions keras_hub/api/layers/__init__.py
@@ -49,6 +49,7 @@
from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageConverter
from keras_hub.src.models.whisper.whisper_audio_converter import (
WhisperAudioConverter,
)
3 changes: 3 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -288,6 +288,9 @@
from keras_hub.src.models.text_to_image import TextToImage
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier
from keras_hub.src.models.vgg.vgg_image_classifier import (
VGGImageClassifierPreprocessor,
)
from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone
from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone
from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer
1 change: 1 addition & 0 deletions keras_hub/src/models/vgg/__init__.py
@@ -0,0 +1 @@
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
3 changes: 1 addition & 2 deletions keras_hub/src/models/vgg/vgg_backbone.py
@@ -47,12 +47,11 @@ def __init__(
image_shape=(None, None, 3),
**kwargs,
):

# === Functional Model ===
img_input = keras.layers.Input(shape=image_shape)
x = img_input

for stack_index in range(len(stackwise_num_repeats) - 1):
for stack_index in range(len(stackwise_num_repeats)):
x = apply_vgg_block(
x=x,
num_layers=stackwise_num_repeats[stack_index],
2 changes: 1 addition & 1 deletion keras_hub/src/models/vgg/vgg_backbone_test.py
@@ -19,7 +19,7 @@ def test_backbone_basics(self):
cls=VGGBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 4, 4, 64),
expected_output_shape=(2, 2, 2, 64),
run_mixed_precision_check=False,
)

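The expected output shape above shrinks because of the loop fix in vgg_backbone.py: each `apply_vgg_block` stack ends with a 2x2 max pool (the same conv/activation/pool layout the timm converter below indexes against), and the old `range(len(...) - 1)` loop skipped the final stack. A rough sketch of the arithmetic; the 16x16 input size is illustrative, not taken from the test:

def spatial_output_size(input_size, num_stacks):
    # Each applied stack halves the spatial dimensions via its 2x2 max pool.
    size = input_size
    for _ in range(num_stacks):
        size //= 2
    return size

# The old loop over a 3-stack config applied only 2 stacks:
print(spatial_output_size(16, 2))  # 4
# The fixed loop applies all 3 stacks, hence the smaller expected shape:
print(spatial_output_size(16, 3))  # 2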
66 changes: 52 additions & 14 deletions keras_hub/src/models/vgg/vgg_image_classifier.py
@@ -1,11 +1,26 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.models.image_classifier import ImageClassifier
from keras_hub.src.models.image_classifier_preprocessor import (
ImageClassifierPreprocessor,
)
from keras_hub.src.models.task import Task
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone


@keras_hub_export("keras_hub.layers.VGGImageConverter")
class VGGImageConverter(ImageConverter):
backbone_cls = VGGBackbone


@keras_hub_export("keras_hub.models.VGGImageClassifierPreprocessor")
class VGGImageClassifierPreprocessor(ImageClassifierPreprocessor):
backbone_cls = VGGBackbone
image_converter_cls = VGGImageConverter


@keras_hub_export("keras_hub.models.VGGImageClassifier")
class VGGImageClassifier(ImageClassifier):
"""VGG image classification task.
@@ -96,13 +111,14 @@ class VGGImageClassifier(ImageClassifier):
"""

backbone_cls = VGGBackbone
preprocessor_cls = VGGImageClassifierPreprocessor

def __init__(
self,
backbone,
num_classes,
preprocessor=None,
pooling="flatten",
pooling="avg",
pooling_hidden_dim=4096,
activation=None,
dropout=0.0,
@@ -141,24 +157,46 @@ def __init__(
"Unknown `pooling` type. Polling should be either `'avg'` or "
f"`'max'`. Received: pooling={pooling}."
)
self.output_dropout = keras.layers.Dropout(
dropout,
dtype=head_dtype,
name="output_dropout",
)
self.output_dense = keras.layers.Dense(
num_classes,
activation=activation,
dtype=head_dtype,
name="predictions",

self.head = keras.Sequential(
[
keras.layers.Conv2D(
filters=4096,
kernel_size=7,
name="fc1",
activation=activation,
use_bias=True,
padding="same",
),
keras.layers.Dropout(
rate=dropout,
dtype=head_dtype,
name="output_dropout",
),
keras.layers.Conv2D(
filters=4096,
kernel_size=1,
name="fc2",
activation=activation,
use_bias=True,
padding="same",
),
self.pooler,
keras.layers.Dense(
num_classes,
activation=activation,
dtype=head_dtype,
name="predictions",
),
],
name="head",
)

# === Functional Model ===
inputs = self.backbone.input
x = self.backbone(inputs)
x = self.pooler(x)
x = self.output_dropout(x)
outputs = self.output_dense(x)
outputs = self.head(x)

# Skip the parent class functional model.
Task.__init__(
self,
4 changes: 2 additions & 2 deletions keras_hub/src/models/vgg/vgg_image_classifier_test.py
@@ -9,12 +9,12 @@
class VGGImageClassifierTest(TestCase):
def setUp(self):
# Setup model.
self.images = np.ones((2, 4, 4, 3), dtype="float32")
self.images = np.ones((2, 8, 8, 3), dtype="float32")
self.labels = [0, 3]
self.backbone = VGGBackbone(
stackwise_num_repeats=[2, 4, 4],
stackwise_num_filters=[2, 16, 16],
image_shape=(4, 4, 3),
image_shape=(8, 8, 3),
)
self.init_kwargs = {
"backbone": self.backbone,
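Putting the new pieces together, a sketch of building the classifier with the small test configuration above and the new converter/preprocessor classes; the `image_size` argument is the generic ImageConverter option and `num_classes=4` is an illustrative value, neither comes from this diff:

import numpy as np

import keras_hub

backbone = keras_hub.models.VGGBackbone(
    stackwise_num_repeats=[2, 4, 4],
    stackwise_num_filters=[2, 16, 16],
    image_shape=(8, 8, 3),
)
preprocessor = keras_hub.models.VGGImageClassifierPreprocessor(
    image_converter=keras_hub.layers.VGGImageConverter(image_size=(8, 8)),
)
classifier = keras_hub.models.VGGImageClassifier(
    backbone=backbone,
    num_classes=4,  # illustrative; the test labels above go up to 3
    preprocessor=preprocessor,
    pooling="avg",  # the new default from this PR
)
outputs = classifier.predict(np.ones((2, 8, 8, 3), dtype="float32"))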
85 changes: 85 additions & 0 deletions keras_hub/src/utils/timm/convert_vgg.py
@@ -0,0 +1,85 @@
from typing import Any

import numpy as np

from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier

backbone_cls = VGGBackbone


REPEATS_BY_SIZE = {
"vgg11": [1, 1, 2, 2, 2],
"vgg13": [2, 2, 2, 2, 2],
"vgg16": [2, 2, 3, 3, 3],
"vgg19": [2, 2, 4, 4, 4],
}


def convert_backbone_config(timm_config):
architecture = timm_config["architecture"]
stackwise_num_repeats = REPEATS_BY_SIZE[architecture]
return dict(
stackwise_num_repeats=stackwise_num_repeats,
stackwise_num_filters=[64, 128, 256, 512, 512],
)


def convert_conv2d(
model,
loader,
keras_layer_name: str,
hf_layer_name: str,
):
loader.port_weight(
model.get_layer(keras_layer_name).kernel,
hf_weight_key=f"{hf_layer_name}.weight",
hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
)
loader.port_weight(
model.get_layer(keras_layer_name).bias,
hf_weight_key=f"{hf_layer_name}.bias",
)


def convert_weights(
backbone: VGGBackbone,
loader,
timm_config: dict[Any],
):
architecture = timm_config["architecture"]
stackwise_num_repeats = REPEATS_BY_SIZE[architecture]

hf_index_to_keras_layer_name = {}
layer_index = 0
for block_index, repeats_in_block in enumerate(stackwise_num_repeats):
for repeat_index in range(repeats_in_block):
hf_index = layer_index
layer_index += 2 # Conv + activation layers.
layer_name = f"block{block_index + 1}_conv{repeat_index + 1}"
hf_index_to_keras_layer_name[hf_index] = layer_name
layer_index += 1 # Pooling layer after blocks.

for hf_index, keras_layer_name in hf_index_to_keras_layer_name.items():
convert_conv2d(
backbone, loader, keras_layer_name, f"features.{hf_index}"
)


def convert_head(
task: VGGImageClassifier,
loader,
timm_config: dict[Any],
):
convert_conv2d(task.head, loader, "fc1", "pre_logits.fc1")
convert_conv2d(task.head, loader, "fc2", "pre_logits.fc2")

loader.port_weight(
task.head.get_layer("predictions").kernel,
hf_weight_key="head.fc.weight",
hook_fn=lambda x, _: np.transpose(np.squeeze(x)),
)
loader.port_weight(
task.head.get_layer("predictions").bias,
hf_weight_key="head.fc.bias",
)
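
To make the indexing in `convert_weights` concrete, here is the `hf_index_to_keras_layer_name` mapping it produces for `vgg11`; indices advance by 2 for each conv + activation pair and by 1 for the pool that closes each block:

# Standalone reproduction of the loop above for REPEATS_BY_SIZE["vgg11"].
stackwise_num_repeats = [1, 1, 2, 2, 2]
mapping = {}
layer_index = 0
for block_index, repeats_in_block in enumerate(stackwise_num_repeats):
    for repeat_index in range(repeats_in_block):
        mapping[layer_index] = f"block{block_index + 1}_conv{repeat_index + 1}"
        layer_index += 2  # conv + activation
    layer_index += 1  # max pool closing the block

# mapping == {
#     0: "block1_conv1", 3: "block2_conv1",
#     6: "block3_conv1", 8: "block3_conv2",
#     11: "block4_conv1", 13: "block4_conv2",
#     16: "block5_conv1", 18: "block5_conv2",
# }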
3 changes: 3 additions & 0 deletions keras_hub/src/utils/timm/preset_loader.py
@@ -5,6 +5,7 @@
from keras_hub.src.utils.preset_utils import jax_memory_cleanup
from keras_hub.src.utils.timm import convert_densenet
from keras_hub.src.utils.timm import convert_resnet
from keras_hub.src.utils.timm import convert_vgg
from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader


@@ -16,6 +17,8 @@ def __init__(self, preset, config):
self.converter = convert_resnet
elif "densenet" in architecture:
self.converter = convert_densenet
elif "vgg" in architecture:
self.converter = convert_vgg
else:
raise ValueError(
"KerasHub has no converter for timm models "
116 changes: 116 additions & 0 deletions tools/checkpoint_conversion/convert_vgg_checkpoints.py
@@ -0,0 +1,116 @@
"""Loads an external VGG model and saves it in Keras format.

Optionally uploads the model to Keras if the `--upload_uri` flag is passed.

python tools/checkpoint_conversion/convert_vgg_checkpoints.py \
--preset vgg11 --upload_uri kaggle://kerashub/vgg/keras/vgg11
"""

import os
import shutil

import keras
import numpy as np
import PIL
import timm
import torch
from absl import app
from absl import flags

import keras_hub

PRESET_MAP = {
"vgg11": "timm/vgg11.tv_in1k",
"vgg13": "timm/vgg13.tv_in1k",
"vgg16": "timm/vgg16.tv_in1k",
"vgg19": "timm/vgg19.tv_in1k",
# TODO(jeffcarp): Add BN variants.
}


PRESET = flags.DEFINE_string(
"preset",
None,
"Must be a valid `VGG` preset from KerasHub",
required=True,
)
UPLOAD_URI = flags.DEFINE_string(
"upload_uri",
None,
'Could be "kaggle://keras/{variant}/keras/{preset}_int8"',
)


def validate_output(keras_model, timm_model):
file = keras.utils.get_file(
origin=(
"https://storage.googleapis.com/keras-cv/"
"models/paligemma/cow_beach_1.png"
)
)
image = PIL.Image.open(file)
batch = np.array([image])

# Preprocess with Timm.
data_config = timm.data.resolve_model_data_config(timm_model)
data_config["crop_pct"] = 1.0 # Stop timm from cropping.
transforms = timm.data.create_transform(**data_config, is_training=False)
timm_preprocessed = transforms(image)
timm_preprocessed = keras.ops.transpose(timm_preprocessed, axes=(1, 2, 0))
timm_preprocessed = keras.ops.expand_dims(timm_preprocessed, 0)

# Preprocess with Keras.
keras_preprocessed = keras_model.preprocessor(batch)

# Call with Timm. Use the keras preprocessed image so we can keep modeling
# and preprocessing comparisons independent.
timm_batch = keras.ops.transpose(keras_preprocessed, axes=(0, 3, 1, 2))
timm_batch = torch.from_numpy(np.array(timm_batch))
timm_outputs = timm_model(timm_batch).detach().numpy()
timm_label = np.argmax(timm_outputs[0])

# Call with Keras.
keras_outputs = keras_model.predict(batch)
keras_label = np.argmax(keras_outputs[0])

print("🔶 Keras output:", keras_outputs[0, :10])
print("🔶 TIMM output:", timm_outputs[0, :10])
print("🔶 Keras label:", keras_label)
print("🔶 TIMM label:", timm_label)
modeling_diff = np.mean(np.abs(keras_outputs - timm_outputs))
print("🔶 Modeling difference:", modeling_diff)
preprocessing_diff = np.mean(np.abs(keras_preprocessed - timm_preprocessed))
print("🔶 Preprocessing difference:", preprocessing_diff)


def main(_):
preset = PRESET.value
if os.path.exists(preset):
shutil.rmtree(preset)
os.makedirs(preset)

timm_name = PRESET_MAP[preset]

timm_model = timm.create_model(timm_name, pretrained=True)
timm_model = timm_model.eval()
print("✅ Loaded TIMM model.")
print(timm_model)

keras_model = keras_hub.models.ImageClassifier.from_preset(
"hf://" + timm_name,
)
print("✅ Loaded KerasHub model.")

keras_model.save_to_preset(f"./{preset}")
print(f"🏁 Preset saved to ./{preset}")

validate_output(keras_model, timm_model)

upload_uri = UPLOAD_URI.value
if upload_uri:
keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}")
print(f"🏁 Preset uploaded to {upload_uri}")


if __name__ == "__main__":
app.run(main)