From 741b8895ddb9209dbf82533427e8501b934cf6ee Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Wed, 13 Nov 2024 18:16:06 -0800 Subject: [PATCH 01/33] vit base --- keras_hub/src/models/vit/vit_backbone.py | 52 +++++ keras_hub/src/models/vit/vit_layers.py | 245 +++++++++++++++++++++++ 2 files changed, 297 insertions(+) create mode 100644 keras_hub/src/models/vit/vit_backbone.py create mode 100644 keras_hub/src/models/vit/vit_layers.py diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py new file mode 100644 index 0000000000..65a3c73775 --- /dev/null +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -0,0 +1,52 @@ +import keras +from keras import ops + +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.utils.keras_utils import standardize_data_format + + + + + + +class ViTBackbone(Backbone): + def __init__( + self, + image_shape, + patch_size, + num_layers, + num_heads, + hidden_dim, + mlp_dim, + dropout, + attention_dropout, + layer_norm_epsilon=1e-6, + data_format=None, + dtype=None, + **kwargs, + ): + data_format = standardize_data_format(data_format) + h_axis, w_axis = ( + (-3, -2) if data_format == "channels_last" else (-2, -1) + ) + # Check that the input image is well specified. + if image_shape[h_axis] is None or image_shape[w_axis] is None: + raise ValueError( + f"Image shape must have defined height and width. Found `None` " + f"at index {h_axis} (height) or {w_axis} (width). " + f"Image shape: {image_shape}" + ) + if image_shape[h_axis] != image_shape[w_axis]: + raise ValueError( + f"Image height and width must be equal. Found height: " + f"{image_shape[h_axis]}, width: {image_shape[w_axis]} at " + f"indices {h_axis} and {w_axis} respectively. Image shape: " + f"{image_shape}" + ) + + # === Layers === + patch_and_embedding = ViTPatchingAndEmbedding( + kernel_size=(patch_size, patch_size), + strides=(patch_size, patch_size), + embed_dim=hidden_dim, + ) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py new file mode 100644 index 0000000000..643798522e --- /dev/null +++ b/keras_hub/src/models/vit/vit_layers.py @@ -0,0 +1,245 @@ +import keras +from keras import ops + +from keras_hub.src.utils.keras_utils import standardize_data_format + + +class TokenLayer(keras.layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def build(self, input_shape): + self.cls_token = self.add_weight( + name="cls", + shape=(1, 1, input_shape[-1]), + initializer="zeros", + dtype=self.dtype_policy, + name="cls_token", + ) + self.built = True + + def call(self, inputs): + cls_token = self.cls_token + keras.ops.zeros_like(inputs[:, 0:1]) + out = keras.ops.concatenate([cls_token, inputs], axis=1) + + return out + + +class MLP(keras.layers.Layer): + def __init__( + self, + hidden_dim, + mlp_dim, + use_bias=True, + dropout_rate=0.0, + dtype=None, + **kwargs, + ): + super().__init__(**kwargs) + + # === config === + self.hidden_dim = hidden_dim + self.mlp_dim = mlp_dim + self.use_bias = use_bias + self.dropout_rate = dropout_rate + + def build(self, input_shape): + self.dense1 = keras.layers.Dense( + units=self.mlp_dim, + use_bias=self.use_bias, + activation="gelu", + bias_initializer=( + keras.initializers.RandomNormal(stddev=1e-6) + if self.use_bias + else None + ), + dtype=self.dtype_policy, + name="dense_1", + ) + self.dense1.build(input_shape) + self.dense2 = keras.layers.Dense( + units=self.hidden_dim, + use_bias=self.use_bias, + bias_initializer=( + 
keras.initializers.RandomNormal(stddev=1e-6) + if self.use_bias + else None + ), + dtype=self.dtype_policy, + name="dense_2", + ) + self.dense2.build((None, None, self.mlp_dim)) + self.dropout = keras.layers.Dropout(self.dropout_rate) + self.built = True + + def call(self, inputs): + x = self.dense1(inputs) + x = self.dense2(x) + out = self.dropout(x) + return out + + +class ViTPatchingAndEmbedding(keras.layers.Layer): + def __init__( + self, + image_size, + patch_size, + hidden_dim, + num_channels=3, + dtype=None, + **kwargs, + ): + super().__init__(**kwargs) + num_patches = (image_size // patch_size) ** 2 + num_positions = num_patches + 1 + + # === config === + self.hidden_dim = hidden_dim + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.num_positions = num_positions + self.dtype = dtype + + def build(self, input_shape): + self.patch_embedding = keras.layers.Conv2D( + filters=self.hidden_dim, + kernel_size=self.patch_size, + strides=self.patch_size, + padding="valid", + activation=None, + dtype=self.dtype_policy, + name="patch_embedding", + ) + self.patch_embedding.build(input_shape) + self.position_embedding = keras.layers.Embedding( + self.num_positions, + self.hidden_dim, + dtype=self.dtype_policy, + name="position_embedding", + ) + self.position_embedding.build([1, self.num_positions]) + self.position_ids = ops.expand_dims( + ops.arange(self.num_positions), axis=0 + ) + self.built = True + + def call(self, input_tokens): + x = self.patch_embedding(input_tokens) + input_shape = ops.shape(x) + x = ops.reshape(x, [input_shape[0], -1, input_shape[-1]]) + x = x + self.position_embedding(self.position_ids) + return x + + def compute_output_shape(self, input_shape): + return ( + input_shape[0], + self.num_positions, + self.hidden_dim, + ) + + +class ViTEncoderBlock(keras.layers.Layer): + def __init__( + self, + num_heads, + hidden_dim, + mlp_dim, + dropout_rate, + attention_dropout, + layer_norm_epsilon, + **kwargs, + ): + super().__init__(**kwargs) + + key_dim = hidden_dim // num_heads + + # === config === + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.key_dim = key_dim + self.mlp_dim = mlp_dim + self.dropout_rate = dropout_rate + self.attention_dropout = attention_dropout + self.layer_norm_epsilon = layer_norm_epsilon + + def build(self, input_shape): + # Attention block + self.layer_norm_1 = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, name="ln_1" + ) + self.layer_norm_1.build(input_shape) + self.mha = keras.layers.MultiHeadAttention( + num_heads=self.num_heads, + key_dim=self.key_dim, + use_bias=False, + dropout=self.attention_dropout, + name="mha", + ) + self.mha.build(input_shape, input_shape) + self.dropout = keras.layers.Dropout(self.dropout_rate) + + # MLP block + self.layer_norm_2 = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, name="ln_2" + ) + self.layer_norm_2.build((None, None, self.hidden_dim)) + self.mlp = MLP( + hidden_dim=self.hidden_dim, mlp_dim=self.mlp_dim, name="mlp" + ) + self.mlp((None, None, self.hidden_dim)) + self.built = True + + def call(self, inputs): + x = self.layer_norm_1(inputs) + x = self.mha(x, x) + x = self.dropout(x) + x = x + inputs + + y = self.layer_norm_2(x) + y = self.mlp(y) + + return x + y + + +class ViTEncoder(keras.layers.Layer): + def __init__( + self, + num_layers, + num_heads, + hidden_dim, + mlp_dim, + dropout, + attention_dropout, + layer_norm_epsilon=1e-6, + **kwargs, + ): + 
super().__init__(**kwargs) + + # === config === + self.num_layers = num_layers + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.mlp_dim = mlp_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.layer_norm_epsilon = layer_norm_epsilon + + def build(self, input_shape): + layers = [] + for i in range(self.num_layers): + encoder_block = ViTEncoderBlock( + num_heads=self.num_heads, + hidden_dim=self.hidden_dim, + mlp_dim=self.mlp_dim, + attention_dropout=self.attention_dropout, + layer_norm_epsilon=self.layer_norm_epsilon, + name=f"tranformer_block_{i+1}", + ) + encoder_block.build((None, None, self.hidden_dim)) + layers.append(encoder_block) + + encoder_layers = keras.Sequential(layers) + + From 13dae08f12308675b6f0b6ab308b8c63815c559f Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 15 Nov 2024 14:07:10 -0800 Subject: [PATCH 02/33] Add vit backbone, classifier and preprocessor layers --- keras_hub/api/layers/__init__.py | 1 + keras_hub/api/models/__init__.py | 5 ++ keras_hub/src/models/vit/__init__.py | 0 keras_hub/src/models/vit/vit_backbone.py | 35 ++++++++--- .../src/models/vit/vit_image_classifier.py | 61 +++++++++++++++++++ .../vit/vit_image_classifier_preprocessor.py | 12 ++++ .../src/models/vit/vit_image_converter.py | 8 +++ keras_hub/src/models/vit/vit_layers.py | 54 +++++++++++----- 8 files changed, 154 insertions(+), 22 deletions(-) create mode 100644 keras_hub/src/models/vit/__init__.py create mode 100644 keras_hub/src/models/vit/vit_image_classifier.py create mode 100644 keras_hub/src/models/vit/vit_image_classifier_preprocessor.py create mode 100644 keras_hub/src/models/vit/vit_image_converter.py diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 2a29cdb64e..09cbb4c4dc 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -62,6 +62,7 @@ SegFormerImageConverter, ) from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter +from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter from keras_hub.src.models.whisper.whisper_audio_converter import ( WhisperAudioConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index dd85a97a45..e43acc1360 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -325,6 +325,11 @@ from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import ( VGGImageClassifierPreprocessor, ) +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier +from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( + ViTImageClassifierPreprocessor, +) from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer diff --git a/keras_hub/src/models/vit/__init__.py b/keras_hub/src/models/vit/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py index 65a3c73775..7dc74997ef 100644 --- a/keras_hub/src/models/vit/vit_backbone.py +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -1,14 +1,13 @@ import keras -from keras import ops +from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.vit.vit_layers import ViTEncoder 
+from keras_hub.src.models.vit.vit_layers import ViTPatchingAndEmbedding from keras_hub.src.utils.keras_utils import standardize_data_format - - - - +@keras_hub_export("keras_hub.models.ViTBackbone") class ViTBackbone(Backbone): def __init__( self, @@ -44,9 +43,31 @@ def __init__( f"{image_shape}" ) - # === Layers === - patch_and_embedding = ViTPatchingAndEmbedding( + # === Functional Model === + inputs = keras.layers.Input(shape=image_shape) + + x = ViTPatchingAndEmbedding( kernel_size=(patch_size, patch_size), strides=(patch_size, patch_size), embed_dim=hidden_dim, + dtype=dtype, + )(inputs) + + x = ViTEncoder( + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + layer_norm_epsilon=layer_norm_epsilon, + dtype=dtype, + )(x) + + output = x[:, 0] + + super().__init__( + inputs=inputs, + outputs=output, + **kwargs, ) diff --git a/keras_hub/src/models/vit/vit_image_classifier.py b/keras_hub/src/models/vit/vit_image_classifier.py new file mode 100644 index 0000000000..e24312e5fc --- /dev/null +++ b/keras_hub/src/models/vit/vit_image_classifier.py @@ -0,0 +1,61 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_classifier import ImageClassifier +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( + ViTImageClassifierPreprocessor, +) + + +@keras_hub_export("keras_hub.models.ViTImageClassifier") +class ViTImageClassifier(ImageClassifier): + backbone_cls = ViTBackbone + preprocessor_cls = ViTImageClassifierPreprocessor + + def __init__( + self, + backbone, + num_classes, + preprocessor=None, + activation=None, + head_dtype=None, + **kwargs, + ): + head_dtype = head_dtype or backbone.dtype_policy + + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + dtype=head_dtype, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_dense(x) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "pooling": self.pooling, + } + ) + return config diff --git a/keras_hub/src/models/vit/vit_image_classifier_preprocessor.py b/keras_hub/src/models/vit/vit_image_classifier_preprocessor.py new file mode 100644 index 0000000000..7e50918eb6 --- /dev/null +++ b/keras_hub/src/models/vit/vit_image_classifier_preprocessor.py @@ -0,0 +1,12 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_classifier_preprocessor import ( + ImageClassifierPreprocessor, +) +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter + + +@keras_hub_export("keras_hub.models.ViTImageClassifierPreprocessor") +class ViTImageClassifierPreprocessor(ImageClassifierPreprocessor): + backbone_cls = ViTBackbone + image_converter_cls = ViTImageConverter diff --git a/keras_hub/src/models/vit/vit_image_converter.py b/keras_hub/src/models/vit/vit_image_converter.py new file mode 100644 index 0000000000..79d3007eaa --- /dev/null +++ 
b/keras_hub/src/models/vit/vit_image_converter.py @@ -0,0 +1,8 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.vit.vit_backbone import ViTBackbone + + +@keras_hub_export("keras_hub.layers.ViTImageConverter") +class ViTImageConverter(ImageConverter): + backbone_cls = ViTBackbone diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 643798522e..8212e09d3f 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -1,3 +1,5 @@ +import math + import keras from keras import ops @@ -10,7 +12,6 @@ def __init__(self, **kwargs): def build(self, input_shape): self.cls_token = self.add_weight( - name="cls", shape=(1, 1, input_shape[-1]), initializer="zeros", dtype=self.dtype_policy, @@ -32,7 +33,6 @@ def __init__( mlp_dim, use_bias=True, dropout_rate=0.0, - dtype=None, **kwargs, ): super().__init__(**kwargs) @@ -69,7 +69,7 @@ def build(self, input_shape): name="dense_2", ) self.dense2.build((None, None, self.mlp_dim)) - self.dropout = keras.layers.Dropout(self.dropout_rate) + self.dropout = keras.layers.Dropout(self.dropout_rate, name="dropout") self.built = True def call(self, inputs): @@ -86,7 +86,7 @@ def __init__( patch_size, hidden_dim, num_channels=3, - dtype=None, + data_format=None, **kwargs, ): super().__init__(**kwargs) @@ -100,7 +100,7 @@ def __init__( self.num_channels = num_channels self.num_patches = num_patches self.num_positions = num_positions - self.dtype = dtype + self.data_format = standardize_data_format(data_format) def build(self, input_shape): self.patch_embedding = keras.layers.Conv2D( @@ -109,10 +109,15 @@ def build(self, input_shape): strides=self.patch_size, padding="valid", activation=None, + kernel_initializer=keras.initializers.RandomNormal( + stddev=math.sqrt(1 / (3 * self.patch_size * self.patch_size)), + ), dtype=self.dtype_policy, + data_format=self.data_format, name="patch_embedding", ) self.patch_embedding.build(input_shape) + self.token_layer = TokenLayer(dtype=self.dtype_policy) self.position_embedding = keras.layers.Embedding( self.num_positions, self.hidden_dim, @@ -125,10 +130,13 @@ def build(self, input_shape): ) self.built = True - def call(self, input_tokens): - x = self.patch_embedding(input_tokens) - input_shape = ops.shape(x) + def call(self, inputs): + x = self.patch_embedding(inputs) + input_shape = ops.shape(x) # (N, H, W, C) or (N, C, H, W) + if self.data_format == "channels_first": + x = ops.transpose(x, axes=(0, 2, 3, 1)) x = ops.reshape(x, [input_shape[0], -1, input_shape[-1]]) + x = self.token_layer(x) x = x + self.position_embedding(self.position_ids) return x @@ -167,7 +175,9 @@ def __init__( def build(self, input_shape): # Attention block self.layer_norm_1 = keras.layers.LayerNormalization( - epsilon=self.layer_norm_epsilon, name="ln_1" + epsilon=self.layer_norm_epsilon, + name="ln_1", + dtype=self.dtype_policy, ) self.layer_norm_1.build(input_shape) self.mha = keras.layers.MultiHeadAttention( @@ -176,17 +186,23 @@ def build(self, input_shape): use_bias=False, dropout=self.attention_dropout, name="mha", + dtype=self.dtype_policy, ) self.mha.build(input_shape, input_shape) - self.dropout = keras.layers.Dropout(self.dropout_rate) + self.dropout = keras.layers.Dropout(self.dropout_rate, name="dropout") # MLP block self.layer_norm_2 = keras.layers.LayerNormalization( - epsilon=self.layer_norm_epsilon, name="ln_2" + epsilon=self.layer_norm_epsilon, 
+ name="ln_2", + dtype=self.dtype_policy, ) self.layer_norm_2.build((None, None, self.hidden_dim)) self.mlp = MLP( - hidden_dim=self.hidden_dim, mlp_dim=self.mlp_dim, name="mlp" + hidden_dim=self.hidden_dim, + mlp_dim=self.mlp_dim, + name="mlp", + dtype=self.dtype_policy, ) self.mlp((None, None, self.hidden_dim)) self.built = True @@ -239,7 +255,15 @@ def build(self, input_shape): ) encoder_block.build((None, None, self.hidden_dim)) layers.append(encoder_block) - - encoder_layers = keras.Sequential(layers) - + self.encoder_layers = keras.Sequential(layers, name="encoder_layers") + self.layer_norm = keras.layers.Normalization( + self.layer_norm_epsilon, name="ln" + ) + self.layer_norm.build((None, None, self.hidden_dim)) + + def call(self, inputs): + x = self.dropout(inputs) + x = self.encoder_layers(x) + x = self.layer_norm(x) + return x From b64b137bb02d6981b5905962f35fae9797e22095 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 15 Nov 2024 15:34:29 -0800 Subject: [PATCH 03/33] update args --- keras_hub/src/models/vit/vit_backbone.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py index 7dc74997ef..66302e128e 100644 --- a/keras_hub/src/models/vit/vit_backbone.py +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -17,8 +17,8 @@ def __init__( num_heads, hidden_dim, mlp_dim, - dropout, - attention_dropout, + dropout=0.0, + attention_dropout=0.0, layer_norm_epsilon=1e-6, data_format=None, dtype=None, @@ -47,9 +47,9 @@ def __init__( inputs = keras.layers.Input(shape=image_shape) x = ViTPatchingAndEmbedding( - kernel_size=(patch_size, patch_size), - strides=(patch_size, patch_size), - embed_dim=hidden_dim, + image_size=image_shape[h_axis], + patch_size=patch_size, + hidden_dim=hidden_dim, dtype=dtype, )(inputs) From 429d6357ecca43c4d9804e27b6aa9636f50c2b58 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 15 Nov 2024 15:44:59 -0800 Subject: [PATCH 04/33] add default args --- keras_hub/src/models/vit/vit_backbone.py | 4 ++-- keras_hub/src/models/vit/vit_layers.py | 15 ++++++++------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py index 66302e128e..693bec4e2b 100644 --- a/keras_hub/src/models/vit/vit_backbone.py +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -17,7 +17,7 @@ def __init__( num_heads, hidden_dim, mlp_dim, - dropout=0.0, + dropout_rate=0.0, attention_dropout=0.0, layer_norm_epsilon=1e-6, data_format=None, @@ -58,7 +58,7 @@ def __init__( num_heads=num_heads, hidden_dim=hidden_dim, mlp_dim=mlp_dim, - dropout=dropout, + dropout_rate=dropout_rate, attention_dropout=attention_dropout, layer_norm_epsilon=layer_norm_epsilon, dtype=dtype, diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 8212e09d3f..bfd1c35af8 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -154,9 +154,9 @@ def __init__( num_heads, hidden_dim, mlp_dim, - dropout_rate, - attention_dropout, - layer_norm_epsilon, + dropout_rate=0.0, + attention_dropout=0.0, + layer_norm_epsilon=1e-6, **kwargs, ): super().__init__(**kwargs) @@ -226,8 +226,8 @@ def __init__( num_heads, hidden_dim, mlp_dim, - dropout, - attention_dropout, + dropout_rate=0.0, + attention_dropout=0.0, layer_norm_epsilon=1e-6, **kwargs, ): @@ -238,7 +238,7 @@ def __init__( self.num_heads = num_heads self.hidden_dim = hidden_dim self.mlp_dim = 
mlp_dim - self.dropout = dropout + self.dropout_rate = dropout_rate self.attention_dropout = attention_dropout self.layer_norm_epsilon = layer_norm_epsilon @@ -249,13 +249,14 @@ def build(self, input_shape): num_heads=self.num_heads, hidden_dim=self.hidden_dim, mlp_dim=self.mlp_dim, + dropout_rate=self.dropout_rate, attention_dropout=self.attention_dropout, layer_norm_epsilon=self.layer_norm_epsilon, name=f"tranformer_block_{i+1}", ) encoder_block.build((None, None, self.hidden_dim)) layers.append(encoder_block) - + self.dropout = keras.layers.Dropout(self.dropout_rate, name="dropout") self.encoder_layers = keras.Sequential(layers, name="encoder_layers") self.layer_norm = keras.layers.Normalization( self.layer_norm_epsilon, name="ln" From 6d69abcda2fbea6d9b3f015f2edec69cfc4acac2 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 15 Nov 2024 15:46:59 -0800 Subject: [PATCH 05/33] correct build method --- keras_hub/src/models/vit/vit_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index bfd1c35af8..2796554144 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -204,7 +204,7 @@ def build(self, input_shape): name="mlp", dtype=self.dtype_policy, ) - self.mlp((None, None, self.hidden_dim)) + self.mlp.build((None, None, self.hidden_dim)) self.built = True def call(self, inputs): From 2e878846d1081857a114e372d27d7533595cd2a5 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 15 Nov 2024 15:51:40 -0800 Subject: [PATCH 06/33] fix build issues --- keras_hub/src/models/vit/vit_layers.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 2796554144..22322d3f3d 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -252,14 +252,17 @@ def build(self, input_shape): dropout_rate=self.dropout_rate, attention_dropout=self.attention_dropout, layer_norm_epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, name=f"tranformer_block_{i+1}", ) encoder_block.build((None, None, self.hidden_dim)) layers.append(encoder_block) self.dropout = keras.layers.Dropout(self.dropout_rate, name="dropout") self.encoder_layers = keras.Sequential(layers, name="encoder_layers") - self.layer_norm = keras.layers.Normalization( - self.layer_norm_epsilon, name="ln" + self.layer_norm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="ln", ) self.layer_norm.build((None, None, self.hidden_dim)) From bd3cce0a1e4d4d69d1f42b64b7f482a474144151 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 15 Nov 2024 16:01:09 -0800 Subject: [PATCH 07/33] fix bugs --- keras_hub/src/models/image_classifier.py | 9 +++- keras_hub/src/models/vit/vit_backbone.py | 4 +- .../src/models/vit/vit_image_classifier.py | 49 ------------------- 3 files changed, 8 insertions(+), 54 deletions(-) diff --git a/keras_hub/src/models/image_classifier.py b/keras_hub/src/models/image_classifier.py index e75e390899..ceafa76cb8 100644 --- a/keras_hub/src/models/image_classifier.py +++ b/keras_hub/src/models/image_classifier.py @@ -117,10 +117,12 @@ def __init__( dtype=head_dtype, name="pooler", ) + elif pooling == "token": + self.pooler = None else: raise ValueError( "Unknown `pooling` type. Polling should be either `'avg'` or " - f"`'max'`. Received: pooling={pooling}." + f"`'max' or 'token'`. 
Received: pooling={pooling}." ) self.output_dropout = keras.layers.Dropout( dropout, @@ -137,7 +139,10 @@ def __init__( # === Functional Model === inputs = self.backbone.input x = self.backbone(inputs) - x = self.pooler(x) + if pooling == "token": # used for Vision Transformer(ViT) + x = x[:, 0] + else: + x = self.pooler(x) x = self.output_dropout(x) outputs = self.output_dense(x) super().__init__( diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py index 693bec4e2b..33e8b610ba 100644 --- a/keras_hub/src/models/vit/vit_backbone.py +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -53,7 +53,7 @@ def __init__( dtype=dtype, )(inputs) - x = ViTEncoder( + output = ViTEncoder( num_layers=num_layers, num_heads=num_heads, hidden_dim=hidden_dim, @@ -64,8 +64,6 @@ def __init__( dtype=dtype, )(x) - output = x[:, 0] - super().__init__( inputs=inputs, outputs=output, diff --git a/keras_hub/src/models/vit/vit_image_classifier.py b/keras_hub/src/models/vit/vit_image_classifier.py index e24312e5fc..1aab26c0d2 100644 --- a/keras_hub/src/models/vit/vit_image_classifier.py +++ b/keras_hub/src/models/vit/vit_image_classifier.py @@ -1,5 +1,3 @@ -import keras - from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.models.vit.vit_backbone import ViTBackbone @@ -12,50 +10,3 @@ class ViTImageClassifier(ImageClassifier): backbone_cls = ViTBackbone preprocessor_cls = ViTImageClassifierPreprocessor - - def __init__( - self, - backbone, - num_classes, - preprocessor=None, - activation=None, - head_dtype=None, - **kwargs, - ): - head_dtype = head_dtype or backbone.dtype_policy - - # === Layers === - self.backbone = backbone - self.preprocessor = preprocessor - - self.output_dense = keras.layers.Dense( - num_classes, - activation=activation, - dtype=head_dtype, - name="predictions", - ) - - # === Functional Model === - inputs = self.backbone.input - x = self.backbone(inputs) - outputs = self.output_dense(x) - super().__init__( - inputs=inputs, - outputs=outputs, - **kwargs, - ) - - # === Config === - self.num_classes = num_classes - self.activation = activation - - def get_config(self): - # Backbone serialized in `super` - config = super().get_config() - config.update( - { - "num_classes": self.num_classes, - "pooling": self.pooling, - } - ) - return config From 4232a0659656ee0912cfb44e455785238240b334 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 18 Nov 2024 12:23:00 -0800 Subject: [PATCH 08/33] Update backbone args and configs --- keras_hub/src/models/vit/vit_backbone.py | 38 +++++++++++- keras_hub/src/models/vit/vit_layers.py | 76 ++++++++++++++++++++---- 2 files changed, 102 insertions(+), 12 deletions(-) diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py index 33e8b610ba..19840de2d2 100644 --- a/keras_hub/src/models/vit/vit_backbone.py +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -25,8 +25,8 @@ def __init__( **kwargs, ): data_format = standardize_data_format(data_format) - h_axis, w_axis = ( - (-3, -2) if data_format == "channels_last" else (-2, -1) + h_axis, w_axis, channels_axis = ( + (-3, -2, -1) if data_format == "channels_last" else (-2, -1, -3) ) # Check that the input image is well specified. 
if image_shape[h_axis] is None or image_shape[w_axis] is None: @@ -43,6 +43,8 @@ def __init__( f"{image_shape}" ) + num_channels = image_shape[channels_axis] + # === Functional Model === inputs = keras.layers.Input(shape=image_shape) @@ -50,7 +52,9 @@ def __init__( image_size=image_shape[h_axis], patch_size=patch_size, hidden_dim=hidden_dim, + num_channels=num_channels, dtype=dtype, + name="vit_patching_and_embedding", )(inputs) output = ViTEncoder( @@ -62,6 +66,7 @@ def __init__( attention_dropout=attention_dropout, layer_norm_epsilon=layer_norm_epsilon, dtype=dtype, + name="vit_encoder", )(x) super().__init__( @@ -69,3 +74,32 @@ def __init__( outputs=output, **kwargs, ) + + # === Config === + self.image_shape = image_shape + self.patch_size = patch_size + self.num_layers = num_layers + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.mlp_dim = mlp_dim + self.dropout_rate = dropout_rate + self.attention_dropout = attention_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.data_format = data_format + + def get_config(self): + config = super().get_config() + config.update( + { + "image_shape": self.image_shape, + "patch_size": self.patch_size, + "num_layers": self.num_layers, + "num_heads": self.num_heads, + "hidden_dim": self.hidden_dim, + "mlp_dim": self.mlp_dim, + "dropout_rate": self.dropout_rate, + "attention_dropout": self.attention_dropout, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 22322d3f3d..5afc1f46d7 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -1,5 +1,3 @@ -import math - import keras from keras import ops @@ -37,7 +35,7 @@ def __init__( ): super().__init__(**kwargs) - # === config === + # === Config === self.hidden_dim = hidden_dim self.mlp_dim = mlp_dim self.use_bias = use_bias @@ -93,10 +91,10 @@ def __init__( num_patches = (image_size // patch_size) ** 2 num_positions = num_patches + 1 - # === config === - self.hidden_dim = hidden_dim + # === Config === self.image_size = image_size self.patch_size = patch_size + self.hidden_dim = hidden_dim self.num_channels = num_channels self.num_patches = num_patches self.num_positions = num_positions @@ -109,9 +107,6 @@ def build(self, input_shape): strides=self.patch_size, padding="valid", activation=None, - kernel_initializer=keras.initializers.RandomNormal( - stddev=math.sqrt(1 / (3 * self.patch_size * self.patch_size)), - ), dtype=self.dtype_policy, data_format=self.data_format, name="patch_embedding", @@ -122,6 +117,7 @@ def build(self, input_shape): self.num_positions, self.hidden_dim, dtype=self.dtype_policy, + embeddings_initializer=keras.initializers.RandomNormal(stddev=0.02), name="position_embedding", ) self.position_embedding.build([1, self.num_positions]) @@ -147,6 +143,20 @@ def compute_output_shape(self, input_shape): self.hidden_dim, ) + def get_config(self): + config = super().get_config() + config.update( + { + "image_size": self.image_size, + "patch_size": self.patch_size, + "hidden_dim": self.hidden_dim, + "num_channels": self.num_channels, + "num_patches": self.num_patches, + "num_positions": self.num_positions, + } + ) + return config + class ViTEncoderBlock(keras.layers.Layer): def __init__( @@ -154,6 +164,8 @@ def __init__( num_heads, hidden_dim, mlp_dim, + use_mha_bias=True, + use_mlp_bias=True, dropout_rate=0.0, attention_dropout=0.0, layer_norm_epsilon=1e-6, @@ -163,11 +175,13 @@ def __init__( key_dim = 
hidden_dim // num_heads - # === config === + # === Config === self.num_heads = num_heads self.hidden_dim = hidden_dim self.key_dim = key_dim self.mlp_dim = mlp_dim + self.use_mha_bias = use_mha_bias + self.use_mlp_bias = use_mlp_bias self.dropout_rate = dropout_rate self.attention_dropout = attention_dropout self.layer_norm_epsilon = layer_norm_epsilon @@ -183,7 +197,7 @@ def build(self, input_shape): self.mha = keras.layers.MultiHeadAttention( num_heads=self.num_heads, key_dim=self.key_dim, - use_bias=False, + use_bias=self.use_mha_bias, dropout=self.attention_dropout, name="mha", dtype=self.dtype_policy, @@ -201,6 +215,7 @@ def build(self, input_shape): self.mlp = MLP( hidden_dim=self.hidden_dim, mlp_dim=self.mlp_dim, + use_bias=self.use_mlp_bias, name="mlp", dtype=self.dtype_policy, ) @@ -218,6 +233,23 @@ def call(self, inputs): return x + y + def get_config(self): + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + "hidden_dim": self.hidden_dim, + "key_dim": self.key_dim, + "mlp_dim": self.mlp_dim, + "use_mha_bias": self.use_mha_bias, + "use_mlp_bias": self.use_mlp_bias, + "dropout_rate": self.dropout_rate, + "attention_dropout": self.attention_dropout, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config + class ViTEncoder(keras.layers.Layer): def __init__( @@ -226,6 +258,8 @@ def __init__( num_heads, hidden_dim, mlp_dim, + use_mha_bias=True, + use_mlp_bias=True, dropout_rate=0.0, attention_dropout=0.0, layer_norm_epsilon=1e-6, @@ -238,6 +272,8 @@ def __init__( self.num_heads = num_heads self.hidden_dim = hidden_dim self.mlp_dim = mlp_dim + self.use_mha_bias = use_mha_bias + self.use_mlp_bias = use_mlp_bias self.dropout_rate = dropout_rate self.attention_dropout = attention_dropout self.layer_norm_epsilon = layer_norm_epsilon @@ -250,6 +286,8 @@ def build(self, input_shape): hidden_dim=self.hidden_dim, mlp_dim=self.mlp_dim, dropout_rate=self.dropout_rate, + use_mha_bias=self.use_mha_bias, + use_mlp_bias=self.use_mlp_bias, attention_dropout=self.attention_dropout, layer_norm_epsilon=self.layer_norm_epsilon, dtype=self.dtype_policy, @@ -265,9 +303,27 @@ def build(self, input_shape): name="ln", ) self.layer_norm.build((None, None, self.hidden_dim)) + self.built = True def call(self, inputs): x = self.dropout(inputs) x = self.encoder_layers(x) x = self.layer_norm(x) return x + + def get_config(self): + config = super().get_config() + config.update( + { + "num_layers": self.num_layers, + "num_heads": self.num_heads, + "hidden_dim": self.hidden_dim, + "mlp_dim": self.mlp_dim, + "use_mha_bias": self.use_mha_bias, + "use_mlp_bias": self.use_mlp_bias, + "dropout_rate": self.dropout_rate, + "attention_dropout": self.attention_dropout, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config From 32b08c5bfecde6ad8e0c1c06a20859f2c1902ed9 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 18 Nov 2024 12:32:27 -0800 Subject: [PATCH 09/33] correct position ids dtype --- keras_hub/src/models/vit/vit_layers.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 5afc1f46d7..af782d6643 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -121,8 +121,15 @@ def build(self, input_shape): name="position_embedding", ) self.position_embedding.build([1, self.num_positions]) - self.position_ids = ops.expand_dims( - ops.arange(self.num_positions), axis=0 + self.position_ids = 
self.add_weight( + shape=(1, self.num_positions), + initializer="zeros", + # Let the backend determine the int dtype. For example, tf + # requires int64 for correct device placement, whereas jax and torch + # don't. + dtype=int, + trainable=False, + name="position_ids", ) self.built = True From cc938c68839c7047c7dfe4fb6317089bac757024 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 18 Nov 2024 13:33:09 -0800 Subject: [PATCH 10/33] build token layer --- keras_hub/src/models/vit/vit_layers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index af782d6643..46cad86e6a 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -113,6 +113,7 @@ def build(self, input_shape): ) self.patch_embedding.build(input_shape) self.token_layer = TokenLayer(dtype=self.dtype_policy) + self.build(input_shape) self.position_embedding = keras.layers.Embedding( self.num_positions, self.hidden_dim, From 78812ded1f54514c437c32f8030a3d00c4f56cb6 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 18 Nov 2024 13:33:41 -0800 Subject: [PATCH 11/33] token layer build --- keras_hub/src/models/vit/vit_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 46cad86e6a..5da27df65e 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -113,7 +113,7 @@ def build(self, input_shape): ) self.patch_embedding.build(input_shape) self.token_layer = TokenLayer(dtype=self.dtype_policy) - self.build(input_shape) + self.token_layer.build(input_shape) self.position_embedding = keras.layers.Embedding( self.num_positions, self.hidden_dim, From 8a2046525118a5ea5d2a1c26b5e2f3b66a752362 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 18 Nov 2024 13:38:27 -0800 Subject: [PATCH 12/33] assign correct dtype to TokenLayer --- keras_hub/src/models/vit/vit_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 5da27df65e..6d60396bf6 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -12,7 +12,7 @@ def build(self, input_shape): self.cls_token = self.add_weight( shape=(1, 1, input_shape[-1]), initializer="zeros", - dtype=self.dtype_policy, + dtype=self.dtype, name="cls_token", ) self.built = True From de754cca099c7858b50ee92cca4f4d387b3becef Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 18 Nov 2024 13:53:34 -0800 Subject: [PATCH 13/33] fix build shape of token layer --- keras_hub/src/models/vit/vit_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 6d60396bf6..237a93ffd9 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -113,7 +113,7 @@ def build(self, input_shape): ) self.patch_embedding.build(input_shape) self.token_layer = TokenLayer(dtype=self.dtype_policy) - self.token_layer.build(input_shape) + self.token_layer.build((None, None, self.hidden_dim)) self.position_embedding = keras.layers.Embedding( self.num_positions, self.hidden_dim, From 84ba8968617c87d92e9713779fd429cc2680da46 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 18 Nov 2024 15:09:06 -0800 Subject: [PATCH 14/33] correct mlp dens var names --- 
keras_hub/src/models/vit/vit_layers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 237a93ffd9..26ae7c5cec 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -42,7 +42,7 @@ def __init__( self.dropout_rate = dropout_rate def build(self, input_shape): - self.dense1 = keras.layers.Dense( + self.dense_1 = keras.layers.Dense( units=self.mlp_dim, use_bias=self.use_bias, activation="gelu", @@ -54,8 +54,8 @@ def build(self, input_shape): dtype=self.dtype_policy, name="dense_1", ) - self.dense1.build(input_shape) - self.dense2 = keras.layers.Dense( + self.dense_1.build(input_shape) + self.dense_2 = keras.layers.Dense( units=self.hidden_dim, use_bias=self.use_bias, bias_initializer=( @@ -66,13 +66,13 @@ def build(self, input_shape): dtype=self.dtype_policy, name="dense_2", ) - self.dense2.build((None, None, self.mlp_dim)) + self.dense_2.build((None, None, self.mlp_dim)) self.dropout = keras.layers.Dropout(self.dropout_rate, name="dropout") self.built = True def call(self, inputs): - x = self.dense1(inputs) - x = self.dense2(x) + x = self.dense_1(inputs) + x = self.dense_2(x) out = self.dropout(x) return out From 7a70e161bedca0233040d9145a190b761617f471 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 18 Nov 2024 15:44:02 -0800 Subject: [PATCH 15/33] use default norm mean and std as per hugging face config --- .../src/models/vit/vit_image_converter.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/keras_hub/src/models/vit/vit_image_converter.py b/keras_hub/src/models/vit/vit_image_converter.py index 79d3007eaa..705c8a8b46 100644 --- a/keras_hub/src/models/vit/vit_image_converter.py +++ b/keras_hub/src/models/vit/vit_image_converter.py @@ -1,8 +1,37 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.utils.tensor_utils import preprocessing_function @keras_hub_export("keras_hub.layers.ViTImageConverter") class ViTImageConverter(ImageConverter): backbone_cls = ViTBackbone + + def __init__( + self, norm_mean=[0.5, 0.5, 0.5], norm_std=[0.5, 0.5, 0.5], **kwargs + ): + super().__init__(**kwargs) + self.norm_mean = norm_mean + self.norm_std = norm_std + + @preprocessing_function + def call(self, inputs): + x = super().call(inputs) + # By default normalize using imagenet mean and std + if self.norm_mean: + x = x - self._expand_non_channel_dims(self.norm_mean, x) + if self.norm_std: + x = x / self._expand_non_channel_dims(self.norm_std, x) + + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "norm_mean": self.norm_mean, + "norm_std": self.norm_std, + } + ) + return config From 81e3021fc6284b01338f969e736185e3f4e57964 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 18 Nov 2024 16:36:16 -0800 Subject: [PATCH 16/33] correct position_ids --- keras_hub/src/models/vit/vit_layers.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 26ae7c5cec..4be0662883 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -122,15 +122,8 @@ def build(self, input_shape): name="position_embedding", ) self.position_embedding.build([1, self.num_positions]) - 
self.position_ids = self.add_weight( - shape=(1, self.num_positions), - initializer="zeros", - # Let the backend determine the int dtype. For example, tf - # requires int64 for correct device placement, whereas jax and torch - # don't. - dtype=int, - trainable=False, - name="position_ids", + self.position_ids = keras.ops.expand_dims( + keras.ops.arange(self.num_positions), axis=0 ) self.built = True From d3061d6210c55639a9b9d090995641011c547d3f Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Tue, 19 Nov 2024 11:18:24 -0800 Subject: [PATCH 17/33] remove separate token layer --- keras_hub/src/models/vit/vit_layers.py | 61 +++++++++++++------------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 4be0662883..3ebf10a9f5 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -4,26 +4,6 @@ from keras_hub.src.utils.keras_utils import standardize_data_format -class TokenLayer(keras.layers.Layer): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def build(self, input_shape): - self.cls_token = self.add_weight( - shape=(1, 1, input_shape[-1]), - initializer="zeros", - dtype=self.dtype, - name="cls_token", - ) - self.built = True - - def call(self, inputs): - cls_token = self.cls_token + keras.ops.zeros_like(inputs[:, 0:1]) - out = keras.ops.concatenate([cls_token, inputs], axis=1) - - return out - - class MLP(keras.layers.Layer): def __init__( self, @@ -101,6 +81,12 @@ def __init__( self.data_format = standardize_data_format(data_format) def build(self, input_shape): + self.class_token = self.add_weight( + shape=(self.hidden_dim,), + initializer="random_normal", + dtype=self.variable_dtype, + name="class_token", + ) self.patch_embedding = keras.layers.Conv2D( filters=self.hidden_dim, kernel_size=self.patch_size, @@ -112,8 +98,6 @@ def build(self, input_shape): name="patch_embedding", ) self.patch_embedding.build(input_shape) - self.token_layer = TokenLayer(dtype=self.dtype_policy) - self.token_layer.build((None, None, self.hidden_dim)) self.position_embedding = keras.layers.Embedding( self.num_positions, self.hidden_dim, @@ -122,20 +106,35 @@ def build(self, input_shape): name="position_embedding", ) self.position_embedding.build([1, self.num_positions]) - self.position_ids = keras.ops.expand_dims( - keras.ops.arange(self.num_positions), axis=0 + self.position_ids = self.add_weight( + shape=(1, self.num_positions), + initializer="zeros", + # Let the backend determine the int dtype. For example, tf + # requires int64 for correct device placement, whereas jax and torch + # don't. 
+ dtype=int, + trainable=False, + name="position_ids", ) self.built = True def call(self, inputs): - x = self.patch_embedding(inputs) - input_shape = ops.shape(x) # (N, H, W, C) or (N, C, H, W) + patch_embeddings = self.patch_embedding(inputs) + input_shape = ops.shape( + patch_embeddings + ) # (N, H, W, C) or (N, C, H, W) if self.data_format == "channels_first": - x = ops.transpose(x, axes=(0, 2, 3, 1)) - x = ops.reshape(x, [input_shape[0], -1, input_shape[-1]]) - x = self.token_layer(x) - x = x + self.position_embedding(self.position_ids) - return x + patch_embeddings = ops.transpose( + patch_embeddings, axes=(0, 2, 3, 1) + ) + patch_embeddings = ops.reshape( + patch_embeddings, [input_shape[0], -1, input_shape[-1]] + ) + class_token = ops.expand_dims(self.class_token, axis=(0, 1)) + class_token = ops.tile(class_token, (input_shape[0], 1, 1)) + position_embeddings = self.position_embedding(self.position_ids) + embeddings = ops.concatenate([class_token, patch_embeddings], axis=1) + return ops.add(embeddings, position_embeddings) def compute_output_shape(self, input_shape): return ( From 618e163cb5b8bcc6d36fcc5e44c89acbd4560d48 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Tue, 19 Nov 2024 11:25:17 -0800 Subject: [PATCH 18/33] correct position ids --- keras_hub/src/models/vit/vit_layers.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 3ebf10a9f5..7ea04d9793 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -106,15 +106,8 @@ def build(self, input_shape): name="position_embedding", ) self.position_embedding.build([1, self.num_positions]) - self.position_ids = self.add_weight( - shape=(1, self.num_positions), - initializer="zeros", - # Let the backend determine the int dtype. For example, tf - # requires int64 for correct device placement, whereas jax and torch - # don't. 
- dtype=int, - trainable=False, - name="position_ids", + self.position_ids = keras.ops.expand_dims( + keras.ops.arange(self.num_positions), axis=0 ) self.built = True From 2338637658d1faef98e69601788bdea931ce1966 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 21 Nov 2024 10:27:59 -0800 Subject: [PATCH 19/33] Checkpoint conversion script and minor changes --- keras_hub/src/models/vit/vit_backbone.py | 4 + .../src/models/vit/vit_image_classifier.py | 3 + keras_hub/src/models/vit/vit_layers.py | 16 +- .../convert_vit_checkpoints.py | 321 ++++++++++++++++++ 4 files changed, 337 insertions(+), 7 deletions(-) create mode 100644 tools/checkpoint_conversion/convert_vit_checkpoints.py diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py index 19840de2d2..027be5aa28 100644 --- a/keras_hub/src/models/vit/vit_backbone.py +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -20,6 +20,8 @@ def __init__( dropout_rate=0.0, attention_dropout=0.0, layer_norm_epsilon=1e-6, + use_mha_bias=True, + use_mlp_bias=True, data_format=None, dtype=None, **kwargs, @@ -65,6 +67,8 @@ def __init__( dropout_rate=dropout_rate, attention_dropout=attention_dropout, layer_norm_epsilon=layer_norm_epsilon, + use_mha_bias=use_mha_bias, + use_mlp_bias=use_mlp_bias, dtype=dtype, name="vit_encoder", )(x) diff --git a/keras_hub/src/models/vit/vit_image_classifier.py b/keras_hub/src/models/vit/vit_image_classifier.py index 1aab26c0d2..579538b6b9 100644 --- a/keras_hub/src/models/vit/vit_image_classifier.py +++ b/keras_hub/src/models/vit/vit_image_classifier.py @@ -10,3 +10,6 @@ class ViTImageClassifier(ImageClassifier): backbone_cls = ViTBackbone preprocessor_cls = ViTImageClassifierPreprocessor + + def __init__(self, pooling="token", **kwargs): + super().__init__(pooling=pooling, **kwargs) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 7ea04d9793..eef58ca49e 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -82,7 +82,11 @@ def __init__( def build(self, input_shape): self.class_token = self.add_weight( - shape=(self.hidden_dim,), + shape=( + 1, + 1, + self.hidden_dim, + ), initializer="random_normal", dtype=self.variable_dtype, name="class_token", @@ -105,7 +109,7 @@ def build(self, input_shape): embeddings_initializer=keras.initializers.RandomNormal(stddev=0.02), name="position_embedding", ) - self.position_embedding.build([1, self.num_positions]) + self.position_embedding.build((1, self.num_positions)) self.position_ids = keras.ops.expand_dims( keras.ops.arange(self.num_positions), axis=0 ) @@ -123,8 +127,7 @@ def call(self, inputs): patch_embeddings = ops.reshape( patch_embeddings, [input_shape[0], -1, input_shape[-1]] ) - class_token = ops.expand_dims(self.class_token, axis=(0, 1)) - class_token = ops.tile(class_token, (input_shape[0], 1, 1)) + class_token = ops.tile(self.class_token, (input_shape[0], 1, 1)) position_embeddings = self.position_embedding(self.position_ids) embeddings = ops.concatenate([class_token, patch_embeddings], axis=1) return ops.add(embeddings, position_embeddings) @@ -272,7 +275,7 @@ def __init__( self.layer_norm_epsilon = layer_norm_epsilon def build(self, input_shape): - layers = [] + self.encoder_layers = keras.Sequential(name="encoder_layers") for i in range(self.num_layers): encoder_block = ViTEncoderBlock( num_heads=self.num_heads, @@ -287,9 +290,8 @@ def build(self, input_shape): name=f"tranformer_block_{i+1}", ) encoder_block.build((None, None, 
self.hidden_dim)) - layers.append(encoder_block) + self.encoder_layers.add(encoder_block) self.dropout = keras.layers.Dropout(self.dropout_rate, name="dropout") - self.encoder_layers = keras.Sequential(layers, name="encoder_layers") self.layer_norm = keras.layers.LayerNormalization( epsilon=self.layer_norm_epsilon, dtype=self.dtype_policy, diff --git a/tools/checkpoint_conversion/convert_vit_checkpoints.py b/tools/checkpoint_conversion/convert_vit_checkpoints.py new file mode 100644 index 0000000000..109f802123 --- /dev/null +++ b/tools/checkpoint_conversion/convert_vit_checkpoints.py @@ -0,0 +1,321 @@ +"""Convert ViT checkpoints. + +export KAGGLE_USERNAME=XXX +export KAGGLE_KEY=XXX + +python tools/checkpoint_conversion/convert_vit_checkpoints.py \ + --preset vit_base_patch16_224 +""" + +import os +import shutil + +import keras +import numpy as np +import torch +from absl import app +from absl import flags +from PIL import Image +from transformers import ViTForImageClassification +from transformers import ViTImageProcessor + +import keras_hub +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier +from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( + ViTImageClassifierPreprocessor, +) +from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter + +FLAGS = flags.FLAGS + +PRESET_MAP = { + "vit_base_patch16_224": "google/vit-base-patch16-224", + "vit_base_patch16_384": "google/vit-base-patch16-384", + "vit_large_patch16_224": "google/vit-large-patch16-224", + "vit_large_patch16_384": "google/vit-large-patch16-384", +} + +flags.DEFINE_string( + "preset", + None, + f'Must be one of {",".join(PRESET_MAP.keys())}', + required=True, +) +flags.DEFINE_string( + "upload_uri", + None, + 'Could be "kaggle://keras/{variant}/keras/{preset}"', + required=False, +) + +flags.DEFINE_string( + "backbone_conversion_only", + False, + "Set to `True` when you want to convert only backbone when classification " + "head weights are not available", +) + + +def convert_model(hf_model): + config = hf_model.config.to_dict() + image_size = config["image_size"] + backbone = ViTBackbone( + image_shape=(image_size, image_size, 3), + patch_size=config["patch_size"], + num_layers=config["num_hidden_layers"], + num_heads=config["num_heads"], + hidden_dim=config["hidden_size"], + mlp_dim=config["intermediate_size"], + dropout_rate=config["hidden_dropout_prob"], + attention_dropout=config["attention_probs_dropout_prob"], + use_mha_bias=config["qkv_bias"], + ) + if FLAGS.backbone_conversion_only: + return backbone + + return ViTImageClassifier( + backbone=backbone, + num_classes=1000, # num classes in ImageNet + ) + + +def convert_weights(keras_hub_model, hf_model): + state_dict = hf_model.state_dict() + state_dict.update(hf_model.named_buffers()) + + # Helper functions. 
+ def port_weights(keras_variable, weight_key, hook_fn=None): + torch_tensor = state_dict[weight_key].cpu().numpy() + if hook_fn: + torch_tensor = hook_fn(torch_tensor, list(keras_variable.shape)) + keras_variable.assign(torch_tensor) + + def port_ln(keras_variable, weight_key): + port_weights(keras_variable.gamma, f"{weight_key}.weight") + port_weights(keras_variable.beta, f"{weight_key}.bias") + + def port_dense(keras_variable, weight_key): + port_weights( + keras_variable.kernel, + f"{weight_key}.weight", + hook_fn=lambda x, _: x.T, + ) + if keras_variable.bias is not None: + port_weights(keras_variable.bias, f"{weight_key}.bias") + + def port_mha(keras_variable, weight_key, num_heads, hidden_dim): + # query + port_weights( + keras_variable.query_dense.kernel, + f"{weight_key}.attention.query.weight", + hook_fn=lambda x, _: np.reshape( + x.T, (hidden_dim, num_heads, hidden_dim // num_heads) + ), + ) + port_weights( + keras_variable.query_dense.bias, + f"{weight_key}.attention.query.bias", + hook_fn=lambda x, _: np.reshape( + x, (num_heads, hidden_dim // num_heads) + ), + ) + # key + port_weights( + keras_variable.key_dense.kernel, + f"{weight_key}.attention.key.weight", + hook_fn=lambda x, _: np.reshape( + x.T, (hidden_dim, num_heads, hidden_dim // num_heads) + ), + ) + port_weights( + keras_variable.key_dense.bias, + f"{weight_key}.attention.key.bias", + hook_fn=lambda x, _: np.reshape( + x, (num_heads, hidden_dim // num_heads) + ), + ) + # value + port_weights( + keras_variable.value_dense.kernel, + f"{weight_key}.attention.value.weight", + hook_fn=lambda x, _: np.reshape( + x.T, (hidden_dim, num_heads, hidden_dim // num_heads) + ), + ) + port_weights( + keras_variable.value_dense.bias, + f"{weight_key}.attention.value.bias", + hook_fn=lambda x, _: np.reshape( + x, (num_heads, hidden_dim // num_heads) + ), + ) + # output + port_weights( + keras_variable.output_dense.kernel, + f"{weight_key}.output.dense.weight", + hook_fn=lambda x, _: np.reshape( + x.T, (num_heads, hidden_dim // num_heads, hidden_dim) + ), + ) + port_weights( + keras_variable.output_dense.bias, f"{weight_key}.output.dense.bias" + ) + + port_weights( + keras_hub_model.backbone.layers[1].patch_embedding.kernel, + "vit.embeddings.patch_embeddings.projection.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + + port_weights( + keras_hub_model.backbone.layers[1].patch_embedding.bias, + "vit.embeddings.patch_embeddings.projection.bias", + ) + + port_weights( + keras_hub_model.backbone.layers[1].class_token, + "vit.embeddings.cls_token", + ) + + port_weights( + keras_hub_model.backbone.layers[1].position_embedding.embeddings, + "vit.embeddings.position_embeddings", + hook_fn=lambda x, _: x[0], + ) + encoder_layers = keras_hub_model.backbone.layers[2].encoder_layers + for i, encoder_block in enumerate(encoder_layers): + prefix = "vit.encoder.layer" + num_heads = encoder_block.num_heads + hidden_dim = encoder_block.hidden_dim + + port_mha( + encoder_block.mha, + f"{prefix}.{i}.attention", + num_heads, + hidden_dim, + ) + port_ln(encoder_block.layer_norm_1, f"{prefix}.{i}.layernorm_before") + port_ln(encoder_block.layer_norm_2, f"{prefix}.{i}.layernorm_after") + + port_dense( + encoder_block.mlp.dense_1, f"{prefix}.{i}.intermediate.dense" + ) + port_dense(encoder_block.mlp.dense_2, f"{prefix}.{i}.output.dense") + + port_ln(keras_hub_model.backbone.layers[2].layer_norm, "vit.layernorm") + if not FLAGS.backbone_conversion_only: + port_dense(keras_hub_model.output_dense, "classifier") + + +def 
convert_image_converter(hf_image_processor): + config = hf_image_processor.to_dict() + image_size = (config["size"]["height"], config["size"]["width"]) + std = config["image_std"] + mean = config["image_mean"] + return ViTImageConverter( + image_size=image_size, + scale=config["rescale_factor"], + norm_mean=mean, + norm_std=std, + interpolation="bilinear", # ViT defaults to bilinear resampling. + ) + + +def validate_output( + keras_model, + keras_image_converter, + hf_model, + hf_image_processor, +): + file = keras.utils.get_file( + origin=("http://images.cocodataset.org/val2017/000000039769.jpg") + ) + image = Image.open(file) + + # Preprocess with hf. + hf_inputs = hf_image_processor( + image, + return_tensors="pt", + ) + hf_preprocessed = hf_inputs["pixel_values"].detach().cpu().numpy() + + # Preprocess with keras. + images = np.expand_dims(np.array(image).astype("float32"), axis=0) + images = np.concatenate([images, images], axis=0) + images = keras_image_converter(images) + keras_preprocessed = keras.ops.convert_to_numpy(images) + + # Call with hf. Use the keras preprocessed image so we can keep modeling + # and preprocessing comparisons independent. + hf_inputs["pixel_values"] = torch.from_numpy( + keras.ops.convert_to_numpy( + keras.ops.transpose(keras_preprocessed, (0, 3, 1, 2)) + ) + ) + hf_outputs = hf_model(**hf_inputs) + hf_vision_logits = hf_outputs.logits.detach().cpu().numpy() + + # Call with keras. + keras_outputs = keras_model(keras_preprocessed) + keras_vision_logits = keras.ops.convert_to_numpy(keras_outputs) + + print("🔶 Keras output:", keras_vision_logits[0, :10]) + print("🔶 HF output:", hf_vision_logits[0, :10]) + modeling_diff = np.mean(np.abs(keras_vision_logits - hf_vision_logits)) + print("🔶 Modeling difference:", modeling_diff) + preprocessing_diff = np.mean( + np.abs(keras_preprocessed - np.transpose(hf_preprocessed, (0, 2, 3, 1))) + ) + print("🔶 Preprocessing difference:", preprocessing_diff) + + +def main(_): + if FLAGS.preset not in PRESET_MAP.keys(): + raise ValueError( + f"Invalid preset {FLAGS.preset}. Must be one " + f"of {','.join(PRESET_MAP.keys())}" + ) + preset = FLAGS.preset + hf_preset = PRESET_MAP[preset] + if os.path.exists(preset): + shutil.rmtree(preset) + os.makedirs(preset) + + print(f"🏃 Coverting {preset}") + + # Load huggingface model. 
+ hf_model = ViTForImageClassification.from_pretrained(hf_preset) + hf_preprocessor = ViTImageProcessor.from_pretrained(hf_preset) + hf_model.eval() + + keras_model = convert_model(hf_model) + keras_image_converter = convert_image_converter(hf_preprocessor) + keras_image_preprocessor = ViTImageClassifierPreprocessor( + image_converter=keras_image_converter + ) + print("✅ KerasHub model loaded.") + + convert_weights(keras_model, hf_model) + print("✅ Weights converted.") + + validate_output( + keras_model, + keras_image_converter, + hf_model, + hf_preprocessor, + ) + print("✅ Output validated.") + + keras_model.save_to_preset(f"./{preset}") + keras_image_preprocessor.save_to_preset(f"./{preset}") + print(f"🏁 Preset saved to ./{preset}.") + + upload_uri = FLAGS.upload_uri + if upload_uri: + keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}") + print(f"🏁 Preset uploaded to {upload_uri}") + + +if __name__ == "__main__": + app.run(main) From 95e58681401f28463b7876d98a36eec9ad24c2dd Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 21 Nov 2024 10:38:46 -0800 Subject: [PATCH 20/33] correct flag type --- tools/checkpoint_conversion/convert_vit_checkpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/checkpoint_conversion/convert_vit_checkpoints.py b/tools/checkpoint_conversion/convert_vit_checkpoints.py index 109f802123..393bad9569 100644 --- a/tools/checkpoint_conversion/convert_vit_checkpoints.py +++ b/tools/checkpoint_conversion/convert_vit_checkpoints.py @@ -49,7 +49,7 @@ required=False, ) -flags.DEFINE_string( +flags.DEFINE_bool( "backbone_conversion_only", False, "Set to `True` when you want to convert only backbone when classification " From 9d2e5bdd73699eb7a957ddd6940e4abc040044d7 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 21 Nov 2024 10:40:51 -0800 Subject: [PATCH 21/33] correct key name --- tools/checkpoint_conversion/convert_vit_checkpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/checkpoint_conversion/convert_vit_checkpoints.py b/tools/checkpoint_conversion/convert_vit_checkpoints.py index 393bad9569..d7282d2dfc 100644 --- a/tools/checkpoint_conversion/convert_vit_checkpoints.py +++ b/tools/checkpoint_conversion/convert_vit_checkpoints.py @@ -64,7 +64,7 @@ def convert_model(hf_model): image_shape=(image_size, image_size, 3), patch_size=config["patch_size"], num_layers=config["num_hidden_layers"], - num_heads=config["num_heads"], + num_heads=config["num_attention_heads"], hidden_dim=config["hidden_size"], mlp_dim=config["intermediate_size"], dropout_rate=config["hidden_dropout_prob"], From ac7d1d3d1f830fd90127d179210ddfd5bf90741c Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 21 Nov 2024 10:45:05 -0800 Subject: [PATCH 22/33] use flat list later we can extract in between layers if needed --- keras_hub/src/models/vit/vit_layers.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index eef58ca49e..449dca0f2e 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -275,7 +275,7 @@ def __init__( self.layer_norm_epsilon = layer_norm_epsilon def build(self, input_shape): - self.encoder_layers = keras.Sequential(name="encoder_layers") + self.encoder_layers = [] for i in range(self.num_layers): encoder_block = ViTEncoderBlock( num_heads=self.num_heads, @@ -290,7 +290,7 @@ def build(self, input_shape): name=f"tranformer_block_{i+1}", ) 
encoder_block.build((None, None, self.hidden_dim)) - self.encoder_layers.add(encoder_block) + self.encoder_layers.append(encoder_block) self.dropout = keras.layers.Dropout(self.dropout_rate, name="dropout") self.layer_norm = keras.layers.LayerNormalization( epsilon=self.layer_norm_epsilon, @@ -302,7 +302,8 @@ def build(self, input_shape): def call(self, inputs): x = self.dropout(inputs) - x = self.encoder_layers(x) + for i in range(self.num_layers): + x = self.encoder_layers[i](x) x = self.layer_norm(x) return x From 8065c01b293545598edda709f44c6df4a8729145 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 21 Nov 2024 14:58:29 -0800 Subject: [PATCH 23/33] Add test cases and correct dtype polciy for model --- keras_hub/src/models/vit/vit_backbone.py | 1 + keras_hub/src/models/vit/vit_backbone_test.py | 36 ++++++++++++ .../models/vit/vit_image_classifier_test.py | 57 +++++++++++++++++++ keras_hub/src/models/vit/vit_layers.py | 12 +++- .../convert_vit_checkpoints.py | 16 +----- 5 files changed, 106 insertions(+), 16 deletions(-) create mode 100644 keras_hub/src/models/vit/vit_backbone_test.py create mode 100644 keras_hub/src/models/vit/vit_image_classifier_test.py diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py index 027be5aa28..fa6b18ec74 100644 --- a/keras_hub/src/models/vit/vit_backbone.py +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -76,6 +76,7 @@ def __init__( super().__init__( inputs=inputs, outputs=output, + dtype=dtype, **kwargs, ) diff --git a/keras_hub/src/models/vit/vit_backbone_test.py b/keras_hub/src/models/vit/vit_backbone_test.py new file mode 100644 index 0000000000..9a0368402e --- /dev/null +++ b/keras_hub/src/models/vit/vit_backbone_test.py @@ -0,0 +1,36 @@ +import pytest +from keras import ops + +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.tests.test_case import TestCase + + +class ViTBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "image_shape": (28, 28, 3), + "patch_size": 4, + "num_layers": 3, + "hidden_dim": 48, + "num_heads": 6, + "mlp_dim": 48 * 4, + "use_mha_bias": True, + } + self.input_size = 28 + self.input_data = ops.ones((2, self.input_size, self.input_size, 3)) + + def test_backbone_basics(self): + self.run_vision_backbone_test( + cls=ViTBackbone, + init_kwargs={**self.init_kwargs}, + input_data=self.input_data, + expected_output_shape=(2, 50, 48), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=ViTBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/vit/vit_image_classifier_test.py b/keras_hub/src/models/vit/vit_image_classifier_test.py new file mode 100644 index 0000000000..a2e6085945 --- /dev/null +++ b/keras_hub/src/models/vit/vit_image_classifier_test.py @@ -0,0 +1,57 @@ +import pytest +from keras import ops + +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier +from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( + ViTImageClassifierPreprocessor, +) +from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter +from keras_hub.src.tests.test_case import TestCase + + +class ViTImageClassifierTest(TestCase): + def setUp(self): + self.images = ops.ones((2, 28, 28, 3)) + self.labels = [0, 1] + self.backbone = ViTBackbone( + image_shape=(28, 28, 3), + patch_size=4, + num_layers=3, + num_heads=6, + hidden_dim=48, + mlp_dim=48 
* 4, + ) + image_converter = ViTImageConverter( + image_size=(28, 28), + scale=1 / 255.0, + ) + preprocessor = ViTImageClassifierPreprocessor( + image_converter=image_converter + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "preprocessor": preprocessor, + } + self.train_data = (self.images, self.labels) + + def test_classifier_basics(self): + self.run_task_test( + cls=ViTImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + def test_head_dtype(self): + model = ViTImageClassifier(**self.init_kwargs, head_dtype="bfloat16") + self.assertEqual(model.output_dense.compute_dtype, "bfloat16") + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=ViTImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 449dca0f2e..36a4b9e8bb 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -47,7 +47,9 @@ def build(self, input_shape): name="dense_2", ) self.dense_2.build((None, None, self.mlp_dim)) - self.dropout = keras.layers.Dropout(self.dropout_rate, name="dropout") + self.dropout = keras.layers.Dropout( + self.dropout_rate, dtype=self.dtype_policy, name="dropout" + ) self.built = True def call(self, inputs): @@ -199,7 +201,9 @@ def build(self, input_shape): dtype=self.dtype_policy, ) self.mha.build(input_shape, input_shape) - self.dropout = keras.layers.Dropout(self.dropout_rate, name="dropout") + self.dropout = keras.layers.Dropout( + self.dropout_rate, dtype=self.dtype_policy, name="dropout" + ) # MLP block self.layer_norm_2 = keras.layers.LayerNormalization( @@ -291,7 +295,9 @@ def build(self, input_shape): ) encoder_block.build((None, None, self.hidden_dim)) self.encoder_layers.append(encoder_block) - self.dropout = keras.layers.Dropout(self.dropout_rate, name="dropout") + self.dropout = keras.layers.Dropout( + self.dropout_rate, dtype=self.dtype_policy, name="dropout" + ) self.layer_norm = keras.layers.LayerNormalization( epsilon=self.layer_norm_epsilon, dtype=self.dtype_policy, diff --git a/tools/checkpoint_conversion/convert_vit_checkpoints.py b/tools/checkpoint_conversion/convert_vit_checkpoints.py index d7282d2dfc..4868a7b040 100644 --- a/tools/checkpoint_conversion/convert_vit_checkpoints.py +++ b/tools/checkpoint_conversion/convert_vit_checkpoints.py @@ -49,13 +49,6 @@ required=False, ) -flags.DEFINE_bool( - "backbone_conversion_only", - False, - "Set to `True` when you want to convert only backbone when classification " - "head weights are not available", -) - def convert_model(hf_model): config = hf_model.config.to_dict() @@ -71,8 +64,6 @@ def convert_model(hf_model): attention_dropout=config["attention_probs_dropout_prob"], use_mha_bias=config["qkv_bias"], ) - if FLAGS.backbone_conversion_only: - return backbone return ViTImageClassifier( backbone=backbone, @@ -204,8 +195,7 @@ def port_mha(keras_variable, weight_key, num_heads, hidden_dim): port_dense(encoder_block.mlp.dense_2, f"{prefix}.{i}.output.dense") port_ln(keras_hub_model.backbone.layers[2].layer_norm, "vit.layernorm") - if not FLAGS.backbone_conversion_only: - port_dense(keras_hub_model.output_dense, "classifier") + port_dense(keras_hub_model.output_dense, "classifier") def convert_image_converter(hf_image_processor): @@ -306,9 +296,9 @@ def main(_): hf_preprocessor, ) print("✅ Output validated.") - + keras_model.preprocessor = 
keras_image_preprocessor keras_model.save_to_preset(f"./{preset}") - keras_image_preprocessor.save_to_preset(f"./{preset}") + print(f"🏁 Preset saved to ./{preset}.") upload_uri = FLAGS.upload_uri From a8be82408f26c2dfff7d9964084cbde9bcac8571 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 21 Nov 2024 15:46:26 -0800 Subject: [PATCH 24/33] add proper docstrings --- keras_hub/src/models/vit/vit_backbone.py | 36 +++++++++++ .../src/models/vit/vit_image_converter.py | 36 +++++++++++ keras_hub/src/models/vit/vit_layers.py | 62 +++++++++++++++++++ 3 files changed, 134 insertions(+) diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py index fa6b18ec74..8fe3d25866 100644 --- a/keras_hub/src/models/vit/vit_backbone.py +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -9,6 +9,42 @@ @keras_hub_export("keras_hub.models.ViTBackbone") class ViTBackbone(Backbone): + """Vision Transformer (ViT) backbone. + + This backbone implements the Vision Transformer architecture as described in + [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929). + It transforms the input image into a sequence of patches, embeds them, and + then processes them through a series of Transformer encoder layers. + + Args: + image_shape: A tuple or list of 3 integers representing the shape of the + input image `(height, width, channels)`, `height` and `width` must + be equal. + patch_size: int. The size of each image patch, the input image will be + divided into patches of shape `(patch_size, patch_size)`. + num_layers: int. The number of transformer encoder layers. + num_heads: int. specifying the number of attention heads in each + Transformer encoder layer. + hidden_dim: int. The dimensionality of the hidden representations. + mlp_dim: int. The dimensionality of the intermediate MLP layer in + each Transformer encoder layer. + dropout_rate: float. The dropout rate for the Transformer encoder + layers. + attention_dropout: float. The dropout rate for the attention mechanism + in each Transformer encoder layer. + layer_norm_epsilon: float. Value used for numerical stability in + layer normalization. + use_mha_bias: bool. Whether to use bias in the multi-head + attention layers. + use_mlp_bias: bool. Whether to use bias in the MLP layers. + data_format: str. `"channels_last"` or `"channels_first"`, specifying + the data format for the input image. If `None`, defaults to + `"channels_last"`. + dtype: The dtype of the layer weights. Defaults to None. + **kwargs: Additional keyword arguments to be passed to the parent + `Backbone` class. + """ + def __init__( self, image_shape, diff --git a/keras_hub/src/models/vit/vit_image_converter.py b/keras_hub/src/models/vit/vit_image_converter.py index 705c8a8b46..b1699640ce 100644 --- a/keras_hub/src/models/vit/vit_image_converter.py +++ b/keras_hub/src/models/vit/vit_image_converter.py @@ -6,6 +6,42 @@ @keras_hub_export("keras_hub.layers.ViTImageConverter") class ViTImageConverter(ImageConverter): + """Converts images to the format expected by a ViT model. + + This layer performs image normalization using mean and standard deviation values. + By default, it uses the same normalization as the + "google/vit-large-patch16-224" model on Hugging Face: + `norm_mean=[0.5, 0.5, 0.5]` and `norm_std=[0.5, 0.5, 0.5]` + ([reference](https://huggingface.co/google/vit-large-patch16-224/blob/main/preprocessor_config.json)). + These defaults are suitable for models pretrained using this normalization. 
+ + Args: + norm_mean: list or tuple of floats. Mean values for image normalization. + Defaults to `[0.5, 0.5, 0.5]`. + norm_std: list or tuple of floats. Standard deviation values for + image normalization. Defaults to `[0.5, 0.5, 0.5]`. + **kwargs: Additional keyword arguments passed to + `keras_hub.layers.preprocessing.ImageConverter`. + + Examples: + ```python + import keras + import numpy as np + from keras_hub.src.layers import ViTImageConverter + + # Example image (replace with your actual image data) + image = np.random.rand(1, 224, 224, 3) # Example: (B, H, W, C) + + # Create a ViTImageConverter instance + converter = ViTImageConverter( + image_size=(28,28), + scale=1/255. + ) + # Preprocess the image + preprocessed_image = converter(image) + ``` + """ + backbone_cls = ViTBackbone def __init__( diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 36a4b9e8bb..cec21e7511 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -5,6 +5,17 @@ class MLP(keras.layers.Layer): + """Multi-Layer Perceptron (MLP) block. + + Args: + hidden_dim: int. Dimensionality of the hidden representations. + mlp_dim: int. Dimensionality of the intermediate MLP layer. + use_bias: bool. Whether to use bias in the dense layers. Defaults to + `True`. + dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to `0.0`. + **kwargs: Additional keyword arguments passed to `keras.layers.Layer` + """ + def __init__( self, hidden_dim, @@ -60,6 +71,20 @@ def call(self, inputs): class ViTPatchingAndEmbedding(keras.layers.Layer): + """Patches the image and embeds the patches. + + Args: + image_size: int. Size of the input image (height or width). + Assumed to be square. + patch_size: int. Size of each image patch. + hidden_dim: int. Dimensionality of the patch embeddings. + num_channels: int. Number of channels in the input image. Defaults to + `3`. + data_format: str. `"channels_last"` or `"channels_first"`. Defaults to + `None` (which uses `"channels_last"`). + **kwargs: Additional keyword arguments passed to `keras.layers.Layer` + """ + def __init__( self, image_size, @@ -157,6 +182,24 @@ def get_config(self): class ViTEncoderBlock(keras.layers.Layer): + """Transformer encoder block. + + Args: + num_heads: int. Number of attention heads. + hidden_dim: int. Dimensionality of the hidden representations. + mlp_dim: int. Dimensionality of the intermediate MLP layer. + use_mha_bias: bool. Whether to use bias in the multi-head attention + layer. Defaults to `True`. + use_mlp_bias: bool. Whether to use bias in the MLP layer. Defaults to + `True`. + dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to `0.0`. + attention_dropout: float. Dropout rate for the attention mechanism. + Between 0 and 1. Defaults to `0.0`. + layer_norm_epsilon: float. Small float value for layer normalization + stability. Defaults to `1e-6`. + **kwargs: Additional keyword arguments passed to `keras.layers.Layer` + """ + def __init__( self, num_heads, @@ -252,6 +295,25 @@ def get_config(self): class ViTEncoder(keras.layers.Layer): + """Vision Transformer (ViT) encoder. + + Args: + num_layers: int. Number of Transformer encoder blocks. + num_heads: int. Number of attention heads. + hidden_dim: int. Dimensionality of the hidden representations. + mlp_dim: int. Dimensionality of the intermediate MLP layer. + use_mha_bias: bool. Whether to use bias in the multi-head attention + layers. Defaults to `True`. + use_mlp_bias: bool. 
Whether to use bias in the MLP layers. Defaults to + `True`. + dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to `0.0`. + attention_dropout: float. Dropout rate for the attention mechanism. + Between 0 and 1. Defaults to `0.0`. + layer_norm_epsilon: float. Small float value for layer normalization + tability. Defaults to `1e-6`. + **kwargs: Additional keyword arguments passed to `keras.layers.Layer` + """ + def __init__( self, num_layers, From 3f027a0ae6b5f8d5abf32acdb3301c7dd4d3c265 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 22 Nov 2024 14:14:15 -0800 Subject: [PATCH 25/33] correct test cases --- keras_hub/src/models/vit/vit_backbone.py | 1 + keras_hub/src/models/vit/vit_backbone_test.py | 3 ++- keras_hub/src/models/vit/vit_layers.py | 8 +++----- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py index 8fe3d25866..d044f8def6 100644 --- a/keras_hub/src/models/vit/vit_backbone.py +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -91,6 +91,7 @@ def __init__( patch_size=patch_size, hidden_dim=hidden_dim, num_channels=num_channels, + data_format=data_format, dtype=dtype, name="vit_patching_and_embedding", )(inputs) diff --git a/keras_hub/src/models/vit/vit_backbone_test.py b/keras_hub/src/models/vit/vit_backbone_test.py index 9a0368402e..0ab0b389ca 100644 --- a/keras_hub/src/models/vit/vit_backbone_test.py +++ b/keras_hub/src/models/vit/vit_backbone_test.py @@ -20,11 +20,12 @@ def setUp(self): self.input_data = ops.ones((2, self.input_size, self.input_size, 3)) def test_backbone_basics(self): - self.run_vision_backbone_test( + self.run_backbone_test( cls=ViTBackbone, init_kwargs={**self.init_kwargs}, input_data=self.input_data, expected_output_shape=(2, 50, 48), + run_quantization_check=False, ) @pytest.mark.large diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index cec21e7511..8cdc52ca71 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -144,17 +144,15 @@ def build(self, input_shape): def call(self, inputs): patch_embeddings = self.patch_embedding(inputs) - input_shape = ops.shape( - patch_embeddings - ) # (N, H, W, C) or (N, C, H, W) if self.data_format == "channels_first": patch_embeddings = ops.transpose( patch_embeddings, axes=(0, 2, 3, 1) ) + embeddings_shape = ops.shape(patch_embeddings) patch_embeddings = ops.reshape( - patch_embeddings, [input_shape[0], -1, input_shape[-1]] + patch_embeddings, [embeddings_shape[0], -1, embeddings_shape[-1]] ) - class_token = ops.tile(self.class_token, (input_shape[0], 1, 1)) + class_token = ops.tile(self.class_token, (embeddings_shape[0], 1, 1)) position_embeddings = self.position_embedding(self.position_ids) embeddings = ops.concatenate([class_token, patch_embeddings], axis=1) return ops.add(embeddings, position_embeddings) From 05acb706a3fee3778442f9810d305449d68830f0 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 25 Nov 2024 13:48:04 -0800 Subject: [PATCH 26/33] use numpy for test data --- keras_hub/src/models/vit/vit_image_classifier_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_hub/src/models/vit/vit_image_classifier_test.py b/keras_hub/src/models/vit/vit_image_classifier_test.py index a2e6085945..b50e511962 100644 --- a/keras_hub/src/models/vit/vit_image_classifier_test.py +++ b/keras_hub/src/models/vit/vit_image_classifier_test.py @@ -1,5 +1,5 @@ import pytest -from keras import ops 
+import numpy as np from keras_hub.src.models.vit.vit_backbone import ViTBackbone from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier @@ -12,7 +12,7 @@ class ViTImageClassifierTest(TestCase): def setUp(self): - self.images = ops.ones((2, 28, 28, 3)) + self.images = np.ones((2, 28, 28, 3)) self.labels = [0, 1] self.backbone = ViTBackbone( image_shape=(28, 28, 3), From 521df6fb3e8248954c530d864f41460faea6f3d7 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 25 Nov 2024 13:55:10 -0800 Subject: [PATCH 27/33] nit --- keras_hub/src/models/vit/vit_image_classifier_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/vit/vit_image_classifier_test.py b/keras_hub/src/models/vit/vit_image_classifier_test.py index b50e511962..29e3d66922 100644 --- a/keras_hub/src/models/vit/vit_image_classifier_test.py +++ b/keras_hub/src/models/vit/vit_image_classifier_test.py @@ -1,5 +1,5 @@ -import pytest import numpy as np +import pytest from keras_hub.src.models.vit.vit_backbone import ViTBackbone from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier From ae2b800fac4d1366f9d635bf66939509a009291f Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Wed, 27 Nov 2024 12:20:13 -0800 Subject: [PATCH 28/33] nit --- keras_hub/src/models/vit/vit_backbone.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py index d044f8def6..c34ab7d498 100644 --- a/keras_hub/src/models/vit/vit_backbone.py +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -62,6 +62,7 @@ def __init__( dtype=None, **kwargs, ): + # === Laters === data_format = standardize_data_format(data_format) h_axis, w_axis, channels_axis = ( (-3, -2, -1) if data_format == "channels_last" else (-2, -1, -3) @@ -127,6 +128,8 @@ def __init__( self.dropout_rate = dropout_rate self.attention_dropout = attention_dropout self.layer_norm_epsilon = layer_norm_epsilon + self.use_mha_bias = use_mha_bias + self.use_mlp_bias = use_mlp_bias self.data_format = data_format def get_config(self): @@ -142,6 +145,8 @@ def get_config(self): "dropout_rate": self.dropout_rate, "attention_dropout": self.attention_dropout, "layer_norm_epsilon": self.layer_norm_epsilon, + "use_mha_bias": self.use_mha_bias, + "use_mlp_bias": self.use_mlp_bias, } ) return config From 92149d5da76b33376718ab7963df60d9b582c062 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 2 Dec 2024 13:57:30 -0800 Subject: [PATCH 29/33] add presets --- keras_hub/src/models/vit/__init__.py | 5 +++ keras_hub/src/models/vit/vit_presets.py | 57 +++++++++++++++++++++++++ 2 files changed, 62 insertions(+) create mode 100644 keras_hub/src/models/vit/vit_presets.py diff --git a/keras_hub/src/models/vit/__init__.py b/keras_hub/src/models/vit/__init__.py index e69de29bb2..e4b42de07d 100644 --- a/keras_hub/src/models/vit/__init__.py +++ b/keras_hub/src/models/vit/__init__.py @@ -0,0 +1,5 @@ +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.vit.vit_presets import backbone_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, ViTBackbone) diff --git a/keras_hub/src/models/vit/vit_presets.py b/keras_hub/src/models/vit/vit_presets.py new file mode 100644 index 0000000000..445372bd40 --- /dev/null +++ b/keras_hub/src/models/vit/vit_presets.py @@ -0,0 +1,57 @@ +"""ViT model preset configurations.""" + +# Metadata for loading pretrained model weights. 
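+# Each entry below maps a preset name to human-readable `metadata` plus the
+# `kaggle_handle` that keras-hub resolves when the weights are downloaded,
+# e.g. via `ViTBackbone.from_preset("vit_base_patch16_224_imagenet")`.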
+backbone_presets = { + "vit_base_patch16_224_imagenet": { + "metadata": { + "description": ( + "ViT-B16 model pre-trained on the ImageNet 1k dataset with " + "image resolution of 224x224 " + ), + "params": 85798656, + "official_name": "ViT", + "path": "vit", + "model_card": "https://www.kaggle.com/models/keras/vit", + }, + "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_224_imagenet/1", + }, + "vit_base_patch16_384_imagenet": { + "metadata": { + "description": ( + "ViT-B16 model pre-trained on the ImageNet 1k dataset with " + "image resolution of 384x384 " + ), + "params": 86090496, + "official_name": "ViT", + "path": "vit", + "model_card": "https://www.kaggle.com/models/keras/vit", + }, + "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_384_imagenet/1", + }, + "vit_large_patch16_224_imagenet": { + "metadata": { + "description": ( + "ViT-L16 model pre-trained on the ImageNet 1k dataset with " + "image resolution of 224x224 " + ), + "params": 303301632, + "official_name": "ViT", + "path": "vit", + "model_card": "https://www.kaggle.com/models/keras/vit", + }, + "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_224_imagenet/1", + }, + "vit_large_patch16_384_imagenet": { + "metadata": { + "description": ( + "ViT-L16 model pre-trained on the ImageNet 1k dataset with " + "image resolution of 384x384 " + ), + "params": 303690752, + "official_name": "ViT", + "path": "vit", + "model_card": "https://www.kaggle.com/models/keras/vit", + }, + "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_384_imagenet/1", + }, +} From 72dc592a760dff6c6863313430c5cd0a6eb10948 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Wed, 11 Dec 2024 16:43:53 -0800 Subject: [PATCH 30/33] move nms, anchor_generator and box_matcher to modeling layers --- keras_hub/api/layers/__init__.py | 2 +- .../modeling}/anchor_generator.py | 13 ++--- .../modeling}/anchor_generator_test.py | 2 +- .../modeling}/box_matcher.py | 3 + .../modeling}/box_matcher_test.py | 2 +- .../modeling}/non_max_supression.py | 58 ++++++++++--------- .../modeling}/non_max_supression_test.py | 6 +- .../retinanet/retinanet_label_encoder.py | 2 +- .../retinanet/retinanet_object_detector.py | 4 +- 9 files changed, 50 insertions(+), 42 deletions(-) rename keras_hub/src/{models/retinanet => layers/modeling}/anchor_generator.py (95%) rename keras_hub/src/{models/retinanet => layers/modeling}/anchor_generator_test.py (96%) rename keras_hub/src/{models/retinanet => layers/modeling}/box_matcher.py (99%) rename keras_hub/src/{models/retinanet => layers/modeling}/box_matcher_test.py (98%) rename keras_hub/src/{models/retinanet => layers/modeling}/non_max_supression.py (94%) rename keras_hub/src/{models/retinanet => layers/modeling}/non_max_supression_test.py (88%) diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 6f44e0ca08..d2e07336f5 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -5,6 +5,7 @@ """ from keras_hub.src.layers.modeling.alibi_bias import AlibiBias +from keras_hub.src.layers.modeling.anchor_generator import AnchorGenerator from keras_hub.src.layers.modeling.cached_multi_head_attention import ( CachedMultiHeadAttention, ) @@ -52,7 +53,6 @@ from keras_hub.src.models.resnet.resnet_image_converter import ( ResNetImageConverter, ) -from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator from keras_hub.src.models.retinanet.retinanet_image_converter import ( RetinaNetImageConverter, ) diff --git 
a/keras_hub/src/models/retinanet/anchor_generator.py b/keras_hub/src/layers/modeling/anchor_generator.py similarity index 95% rename from keras_hub/src/models/retinanet/anchor_generator.py rename to keras_hub/src/layers/modeling/anchor_generator.py index a3c3800c49..2b467f72ff 100644 --- a/keras_hub/src/models/retinanet/anchor_generator.py +++ b/keras_hub/src/layers/modeling/anchor_generator.py @@ -5,9 +5,6 @@ from keras_hub.src.api_export import keras_hub_export -# TODO: https://github.com/keras-team/keras-hub/issues/1965 -from keras_hub.src.bounding_box.converters import convert_format - @keras_hub_export("keras_hub.layers.AnchorGenerator") class AnchorGenerator(keras.layers.Layer): @@ -133,10 +130,12 @@ def call(self, inputs): anchors = shifts + base_anchors anchors = ops.reshape(anchors, (-1, 4)) - multilevel_anchors[f"P{level}"] = convert_format( - anchors, - source="xyxy", - target=self.bounding_box_format, + multilevel_anchors[f"P{level}"] = ( + keras.utils.bounding_boxes.convert_format( + anchors, + source="xyxy", + target=self.bounding_box_format, + ) ) return multilevel_anchors diff --git a/keras_hub/src/models/retinanet/anchor_generator_test.py b/keras_hub/src/layers/modeling/anchor_generator_test.py similarity index 96% rename from keras_hub/src/models/retinanet/anchor_generator_test.py rename to keras_hub/src/layers/modeling/anchor_generator_test.py index 0b71630843..f3bc2510de 100644 --- a/keras_hub/src/models/retinanet/anchor_generator_test.py +++ b/keras_hub/src/layers/modeling/anchor_generator_test.py @@ -2,7 +2,7 @@ from absl.testing import parameterized from keras import ops -from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator +from keras_hub.src.layers.modeling.anchor_generator import AnchorGenerator from keras_hub.src.tests.test_case import TestCase diff --git a/keras_hub/src/models/retinanet/box_matcher.py b/keras_hub/src/layers/modeling/box_matcher.py similarity index 99% rename from keras_hub/src/models/retinanet/box_matcher.py rename to keras_hub/src/layers/modeling/box_matcher.py index dd8a486814..b841e8deb5 100644 --- a/keras_hub/src/models/retinanet/box_matcher.py +++ b/keras_hub/src/layers/modeling/box_matcher.py @@ -1,7 +1,10 @@ import keras from keras import ops +from keras_hub.src.api_export import keras_hub_export + +@keras_hub_export("keras_hub.layers.BoxMatcher") class BoxMatcher(keras.layers.Layer): """Box matching logic based on argmax of highest value (e.g., IOU). 
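For downstream code, the visible effect of this move is an import-path change. A minimal sketch of the intended public paths follows (an assumption here: the generated `keras_hub.layers` API is regenerated after this commit, as the updated `api/layers/__init__.py` and the new `keras_hub_export` decorator suggest):

```python
# Hedged sketch of the post-move import paths; the old
# keras_hub.src.models.retinanet locations no longer exist after this commit.
from keras_hub.layers import AnchorGenerator  # re-exported from layers/modeling
from keras_hub.layers import BoxMatcher  # newly exposed via keras_hub_export
```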
diff --git a/keras_hub/src/models/retinanet/box_matcher_test.py b/keras_hub/src/layers/modeling/box_matcher_test.py similarity index 98% rename from keras_hub/src/models/retinanet/box_matcher_test.py rename to keras_hub/src/layers/modeling/box_matcher_test.py index d991f90e5b..5fdf39a7ac 100644 --- a/keras_hub/src/models/retinanet/box_matcher_test.py +++ b/keras_hub/src/layers/modeling/box_matcher_test.py @@ -1,7 +1,7 @@ import numpy as np from keras import ops -from keras_hub.src.models.retinanet.box_matcher import BoxMatcher +from keras_hub.src.layers.modeling.box_matcher import BoxMatcher from keras_hub.src.tests.test_case import TestCase diff --git a/keras_hub/src/models/retinanet/non_max_supression.py b/keras_hub/src/layers/modeling/non_max_supression.py similarity index 94% rename from keras_hub/src/models/retinanet/non_max_supression.py rename to keras_hub/src/layers/modeling/non_max_supression.py index 5ca52b4dfc..70595492e5 100644 --- a/keras_hub/src/models/retinanet/non_max_supression.py +++ b/keras_hub/src/layers/modeling/non_max_supression.py @@ -2,22 +2,22 @@ import keras from keras import ops +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes import ( + validation, +) -# TODO: https://github.com/keras-team/keras-hub/issues/1965 -from keras_hub.src.bounding_box import converters -from keras_hub.src.bounding_box import utils -from keras_hub.src.bounding_box import validate_format +from keras_hub.src.api_export import keras_hub_export EPSILON = 1e-8 +@keras_hub_export("keras_hub.layers.NonMaxSupression") class NonMaxSuppression(keras.layers.Layer): """A Keras layer that decodes predictions of an object detection model. Args: bounding_box_format: The format of bounding boxes of input dataset. - Refer - TODO: link keras core bounding box docs + Refer: for more details on supported bounding box formats. from_logits: boolean, True means input score is logits, False means confidence. @@ -49,7 +49,10 @@ def __init__( self.built = True def call( - self, box_prediction, class_prediction, images=None, image_shape=None + self, + box_prediction, + class_prediction, + images=None, ): """Accepts images and raw scores, returning bounding box predictions. @@ -59,15 +62,24 @@ def call( class_prediction: Dense Tensor of shape [batch, boxes, num_classes]. """ target_format = "yxyx" - if utils.is_relative(self.bounding_box_format): - target_format = utils.as_relative(target_format) + height, width = None, None + + if "rel" in self.bounding_box_format and images is None: + raise ValueError( + "`images` cannot be None when using relative " + "bounding box format." 
+ ) + + if "rel" in self.bounding_box_format: + target_format = "rel_" + target_format + height, width, _ = ops.shape(images) - box_prediction = converters.convert_format( + box_prediction = keras.utils.bounding_boxes.convert_format( box_prediction, source=self.bounding_box_format, target=target_format, - images=images, - image_shape=image_shape, + height=height, + width=width, ) if self.from_logits: class_prediction = ops.sigmoid(class_prediction) @@ -95,17 +107,17 @@ def call( class_prediction, ops.expand_dims(idx, axis=-1), axis=1 ) - box_prediction = converters.convert_format( + box_prediction = keras.utils.bounding_boxes.convert_format( box_prediction, source=target_format, target=self.bounding_box_format, - images=images, - image_shape=image_shape, + height=height, + width=width, ) bounding_boxes = { "boxes": box_prediction, "confidence": confidence_prediction, - "classes": ops.argmax(class_prediction, axis=-1), + "labels": ops.argmax(class_prediction, axis=-1), "num_detections": valid_det, } @@ -519,14 +531,8 @@ def mask_invalid_detections(bounding_boxes): returned value will also return `tf.RaggedTensor` representations. """ # ensure we are complying with Keras bounding box format. - info = validate_format.validate_format(bounding_boxes) - if info["ragged"]: - raise ValueError( - "`bounding_box.mask_invalid_detections()` requires inputs to be " - "Dense tensors. Please call " - "`bounding_box.to_dense(bounding_boxes)` before passing your boxes " - "to `bounding_box.mask_invalid_detections()`." - ) + validation.validate_bounding_boxes(bounding_boxes) + if "num_detections" not in bounding_boxes: raise ValueError( "`bounding_boxes` must have key 'num_detections' " @@ -534,7 +540,7 @@ def mask_invalid_detections(bounding_boxes): ) boxes = bounding_boxes.get("boxes") - classes = bounding_boxes.get("classes") + classes = bounding_boxes.get("labels") confidence = bounding_boxes.get("confidence", None) num_detections = bounding_boxes.get("num_detections") @@ -558,7 +564,7 @@ def mask_invalid_detections(bounding_boxes): result = bounding_boxes.copy() result["boxes"] = boxes - result["classes"] = classes + result["labels"] = classes if confidence is not None: result["confidence"] = confidence diff --git a/keras_hub/src/models/retinanet/non_max_supression_test.py b/keras_hub/src/layers/modeling/non_max_supression_test.py similarity index 88% rename from keras_hub/src/models/retinanet/non_max_supression_test.py rename to keras_hub/src/layers/modeling/non_max_supression_test.py index 94d3c3f124..4f310a0934 100644 --- a/keras_hub/src/models/retinanet/non_max_supression_test.py +++ b/keras_hub/src/layers/modeling/non_max_supression_test.py @@ -1,7 +1,7 @@ import numpy as np from keras import ops -from keras_hub.src.models.retinanet.non_max_supression import NonMaxSuppression +from keras_hub.src.layers.modeling.non_max_supression import NonMaxSuppression from keras_hub.src.tests.test_case import TestCase @@ -29,7 +29,7 @@ def test_confidence_threshold(self): self.assertAllClose( outputs["boxes"], [boxes[0][-2:, ...], boxes[1][:2, ...]] ) - self.assertAllClose(outputs["classes"], [[0.0, 0.0], [0.0, 0.0]]) + self.assertAllClose(outputs["labels"], [[0.0, 0.0], [0.0, 0.0]]) self.assertAllClose(outputs["confidence"], [[0.9, 0.5], [0.7, 0.5]]) def test_max_detections(self): @@ -55,5 +55,5 @@ def test_max_detections(self): self.assertAllClose( outputs["boxes"], [boxes[0][-1:, ...], boxes[1][:1, ...]] ) - self.assertAllClose(outputs["classes"], [[0.0], [0.0]]) + self.assertAllClose(outputs["labels"], 
[[0.0], [0.0]]) self.assertAllClose(outputs["confidence"], [[0.9], [0.7]]) diff --git a/keras_hub/src/models/retinanet/retinanet_label_encoder.py b/keras_hub/src/models/retinanet/retinanet_label_encoder.py index 66a6ff6d78..bfb74c83da 100644 --- a/keras_hub/src/models/retinanet/retinanet_label_encoder.py +++ b/keras_hub/src/models/retinanet/retinanet_label_encoder.py @@ -7,7 +7,7 @@ from keras_hub.src.bounding_box.converters import convert_format from keras_hub.src.bounding_box.converters import encode_box_to_deltas from keras_hub.src.bounding_box.iou import compute_iou -from keras_hub.src.models.retinanet.box_matcher import BoxMatcher +from keras_hub.src.layers.modeling.box_matcher import BoxMatcher from keras_hub.src.utils import tensor_utils diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index 95b4d754fe..ba343a8d0d 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -6,9 +6,9 @@ # TODO: https://github.com/keras-team/keras-hub/issues/1965 from keras_hub.src.bounding_box.converters import convert_format from keras_hub.src.bounding_box.converters import decode_deltas_to_boxes +from keras_hub.src.layers.modeling.anchor_generator import AnchorGenerator +from keras_hub.src.layers.modeling.non_max_supression import NonMaxSuppression from keras_hub.src.models.image_object_detector import ImageObjectDetector -from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator -from keras_hub.src.models.retinanet.non_max_supression import NonMaxSuppression from keras_hub.src.models.retinanet.prediction_head import PredictionHead from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone from keras_hub.src.models.retinanet.retinanet_label_encoder import ( From fa65d242c3ed2a567cedc6f9b378934029bcd4db Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Wed, 11 Dec 2024 16:43:53 -0800 Subject: [PATCH 31/33] move nms, anchor_generator and box_matcher to modeling layers --- keras_hub/api/layers/__init__.py | 2 +- .../modeling}/anchor_generator.py | 13 ++--- .../modeling}/anchor_generator_test.py | 2 +- .../modeling}/box_matcher.py | 3 + .../modeling}/box_matcher_test.py | 2 +- .../modeling}/non_max_supression.py | 58 ++++++++++--------- .../modeling}/non_max_supression_test.py | 6 +- .../retinanet/retinanet_label_encoder.py | 2 +- .../retinanet/retinanet_object_detector.py | 4 +- 9 files changed, 50 insertions(+), 42 deletions(-) rename keras_hub/src/{models/retinanet => layers/modeling}/anchor_generator.py (95%) rename keras_hub/src/{models/retinanet => layers/modeling}/anchor_generator_test.py (96%) rename keras_hub/src/{models/retinanet => layers/modeling}/box_matcher.py (99%) rename keras_hub/src/{models/retinanet => layers/modeling}/box_matcher_test.py (98%) rename keras_hub/src/{models/retinanet => layers/modeling}/non_max_supression.py (94%) rename keras_hub/src/{models/retinanet => layers/modeling}/non_max_supression_test.py (88%) diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index f389052f8e..cad2c47a4c 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -5,6 +5,7 @@ """ from keras_hub.src.layers.modeling.alibi_bias import AlibiBias +from keras_hub.src.layers.modeling.anchor_generator import AnchorGenerator from keras_hub.src.layers.modeling.cached_multi_head_attention import ( CachedMultiHeadAttention, ) @@ -52,7 +53,6 
@@ from keras_hub.src.models.resnet.resnet_image_converter import ( ResNetImageConverter, ) -from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator from keras_hub.src.models.retinanet.retinanet_image_converter import ( RetinaNetImageConverter, ) diff --git a/keras_hub/src/models/retinanet/anchor_generator.py b/keras_hub/src/layers/modeling/anchor_generator.py similarity index 95% rename from keras_hub/src/models/retinanet/anchor_generator.py rename to keras_hub/src/layers/modeling/anchor_generator.py index a3c3800c49..2b467f72ff 100644 --- a/keras_hub/src/models/retinanet/anchor_generator.py +++ b/keras_hub/src/layers/modeling/anchor_generator.py @@ -5,9 +5,6 @@ from keras_hub.src.api_export import keras_hub_export -# TODO: https://github.com/keras-team/keras-hub/issues/1965 -from keras_hub.src.bounding_box.converters import convert_format - @keras_hub_export("keras_hub.layers.AnchorGenerator") class AnchorGenerator(keras.layers.Layer): @@ -133,10 +130,12 @@ def call(self, inputs): anchors = shifts + base_anchors anchors = ops.reshape(anchors, (-1, 4)) - multilevel_anchors[f"P{level}"] = convert_format( - anchors, - source="xyxy", - target=self.bounding_box_format, + multilevel_anchors[f"P{level}"] = ( + keras.utils.bounding_boxes.convert_format( + anchors, + source="xyxy", + target=self.bounding_box_format, + ) ) return multilevel_anchors diff --git a/keras_hub/src/models/retinanet/anchor_generator_test.py b/keras_hub/src/layers/modeling/anchor_generator_test.py similarity index 96% rename from keras_hub/src/models/retinanet/anchor_generator_test.py rename to keras_hub/src/layers/modeling/anchor_generator_test.py index 0b71630843..f3bc2510de 100644 --- a/keras_hub/src/models/retinanet/anchor_generator_test.py +++ b/keras_hub/src/layers/modeling/anchor_generator_test.py @@ -2,7 +2,7 @@ from absl.testing import parameterized from keras import ops -from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator +from keras_hub.src.layers.modeling.anchor_generator import AnchorGenerator from keras_hub.src.tests.test_case import TestCase diff --git a/keras_hub/src/models/retinanet/box_matcher.py b/keras_hub/src/layers/modeling/box_matcher.py similarity index 99% rename from keras_hub/src/models/retinanet/box_matcher.py rename to keras_hub/src/layers/modeling/box_matcher.py index dd8a486814..b841e8deb5 100644 --- a/keras_hub/src/models/retinanet/box_matcher.py +++ b/keras_hub/src/layers/modeling/box_matcher.py @@ -1,7 +1,10 @@ import keras from keras import ops +from keras_hub.src.api_export import keras_hub_export + +@keras_hub_export("keras_hub.layers.BoxMatcher") class BoxMatcher(keras.layers.Layer): """Box matching logic based on argmax of highest value (e.g., IOU). 
diff --git a/keras_hub/src/models/retinanet/box_matcher_test.py b/keras_hub/src/layers/modeling/box_matcher_test.py similarity index 98% rename from keras_hub/src/models/retinanet/box_matcher_test.py rename to keras_hub/src/layers/modeling/box_matcher_test.py index d991f90e5b..5fdf39a7ac 100644 --- a/keras_hub/src/models/retinanet/box_matcher_test.py +++ b/keras_hub/src/layers/modeling/box_matcher_test.py @@ -1,7 +1,7 @@ import numpy as np from keras import ops -from keras_hub.src.models.retinanet.box_matcher import BoxMatcher +from keras_hub.src.layers.modeling.box_matcher import BoxMatcher from keras_hub.src.tests.test_case import TestCase diff --git a/keras_hub/src/models/retinanet/non_max_supression.py b/keras_hub/src/layers/modeling/non_max_supression.py similarity index 94% rename from keras_hub/src/models/retinanet/non_max_supression.py rename to keras_hub/src/layers/modeling/non_max_supression.py index 5ca52b4dfc..70595492e5 100644 --- a/keras_hub/src/models/retinanet/non_max_supression.py +++ b/keras_hub/src/layers/modeling/non_max_supression.py @@ -2,22 +2,22 @@ import keras from keras import ops +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes import ( + validation, +) -# TODO: https://github.com/keras-team/keras-hub/issues/1965 -from keras_hub.src.bounding_box import converters -from keras_hub.src.bounding_box import utils -from keras_hub.src.bounding_box import validate_format +from keras_hub.src.api_export import keras_hub_export EPSILON = 1e-8 +@keras_hub_export("keras_hub.layers.NonMaxSupression") class NonMaxSuppression(keras.layers.Layer): """A Keras layer that decodes predictions of an object detection model. Args: bounding_box_format: The format of bounding boxes of input dataset. - Refer - TODO: link keras core bounding box docs + Refer: for more details on supported bounding box formats. from_logits: boolean, True means input score is logits, False means confidence. @@ -49,7 +49,10 @@ def __init__( self.built = True def call( - self, box_prediction, class_prediction, images=None, image_shape=None + self, + box_prediction, + class_prediction, + images=None, ): """Accepts images and raw scores, returning bounding box predictions. @@ -59,15 +62,24 @@ def call( class_prediction: Dense Tensor of shape [batch, boxes, num_classes]. """ target_format = "yxyx" - if utils.is_relative(self.bounding_box_format): - target_format = utils.as_relative(target_format) + height, width = None, None + + if "rel" in self.bounding_box_format and images is None: + raise ValueError( + "`images` cannot be None when using relative " + "bounding box format." 
+ ) + + if "rel" in self.bounding_box_format: + target_format = "rel_" + target_format + height, width, _ = ops.shape(images) - box_prediction = converters.convert_format( + box_prediction = keras.utils.bounding_boxes.convert_format( box_prediction, source=self.bounding_box_format, target=target_format, - images=images, - image_shape=image_shape, + height=height, + width=width, ) if self.from_logits: class_prediction = ops.sigmoid(class_prediction) @@ -95,17 +107,17 @@ def call( class_prediction, ops.expand_dims(idx, axis=-1), axis=1 ) - box_prediction = converters.convert_format( + box_prediction = keras.utils.bounding_boxes.convert_format( box_prediction, source=target_format, target=self.bounding_box_format, - images=images, - image_shape=image_shape, + height=height, + width=width, ) bounding_boxes = { "boxes": box_prediction, "confidence": confidence_prediction, - "classes": ops.argmax(class_prediction, axis=-1), + "labels": ops.argmax(class_prediction, axis=-1), "num_detections": valid_det, } @@ -519,14 +531,8 @@ def mask_invalid_detections(bounding_boxes): returned value will also return `tf.RaggedTensor` representations. """ # ensure we are complying with Keras bounding box format. - info = validate_format.validate_format(bounding_boxes) - if info["ragged"]: - raise ValueError( - "`bounding_box.mask_invalid_detections()` requires inputs to be " - "Dense tensors. Please call " - "`bounding_box.to_dense(bounding_boxes)` before passing your boxes " - "to `bounding_box.mask_invalid_detections()`." - ) + validation.validate_bounding_boxes(bounding_boxes) + if "num_detections" not in bounding_boxes: raise ValueError( "`bounding_boxes` must have key 'num_detections' " @@ -534,7 +540,7 @@ def mask_invalid_detections(bounding_boxes): ) boxes = bounding_boxes.get("boxes") - classes = bounding_boxes.get("classes") + classes = bounding_boxes.get("labels") confidence = bounding_boxes.get("confidence", None) num_detections = bounding_boxes.get("num_detections") @@ -558,7 +564,7 @@ def mask_invalid_detections(bounding_boxes): result = bounding_boxes.copy() result["boxes"] = boxes - result["classes"] = classes + result["labels"] = classes if confidence is not None: result["confidence"] = confidence diff --git a/keras_hub/src/models/retinanet/non_max_supression_test.py b/keras_hub/src/layers/modeling/non_max_supression_test.py similarity index 88% rename from keras_hub/src/models/retinanet/non_max_supression_test.py rename to keras_hub/src/layers/modeling/non_max_supression_test.py index 94d3c3f124..4f310a0934 100644 --- a/keras_hub/src/models/retinanet/non_max_supression_test.py +++ b/keras_hub/src/layers/modeling/non_max_supression_test.py @@ -1,7 +1,7 @@ import numpy as np from keras import ops -from keras_hub.src.models.retinanet.non_max_supression import NonMaxSuppression +from keras_hub.src.layers.modeling.non_max_supression import NonMaxSuppression from keras_hub.src.tests.test_case import TestCase @@ -29,7 +29,7 @@ def test_confidence_threshold(self): self.assertAllClose( outputs["boxes"], [boxes[0][-2:, ...], boxes[1][:2, ...]] ) - self.assertAllClose(outputs["classes"], [[0.0, 0.0], [0.0, 0.0]]) + self.assertAllClose(outputs["labels"], [[0.0, 0.0], [0.0, 0.0]]) self.assertAllClose(outputs["confidence"], [[0.9, 0.5], [0.7, 0.5]]) def test_max_detections(self): @@ -55,5 +55,5 @@ def test_max_detections(self): self.assertAllClose( outputs["boxes"], [boxes[0][-1:, ...], boxes[1][:1, ...]] ) - self.assertAllClose(outputs["classes"], [[0.0], [0.0]]) + self.assertAllClose(outputs["labels"], 
[[0.0], [0.0]]) self.assertAllClose(outputs["confidence"], [[0.9], [0.7]]) diff --git a/keras_hub/src/models/retinanet/retinanet_label_encoder.py b/keras_hub/src/models/retinanet/retinanet_label_encoder.py index 66a6ff6d78..bfb74c83da 100644 --- a/keras_hub/src/models/retinanet/retinanet_label_encoder.py +++ b/keras_hub/src/models/retinanet/retinanet_label_encoder.py @@ -7,7 +7,7 @@ from keras_hub.src.bounding_box.converters import convert_format from keras_hub.src.bounding_box.converters import encode_box_to_deltas from keras_hub.src.bounding_box.iou import compute_iou -from keras_hub.src.models.retinanet.box_matcher import BoxMatcher +from keras_hub.src.layers.modeling.box_matcher import BoxMatcher from keras_hub.src.utils import tensor_utils diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index 14b3a631c5..d8b9e17304 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -6,9 +6,9 @@ # TODO: https://github.com/keras-team/keras-hub/issues/1965 from keras_hub.src.bounding_box.converters import convert_format from keras_hub.src.bounding_box.converters import decode_deltas_to_boxes +from keras_hub.src.layers.modeling.anchor_generator import AnchorGenerator +from keras_hub.src.layers.modeling.non_max_supression import NonMaxSuppression from keras_hub.src.models.image_object_detector import ImageObjectDetector -from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator -from keras_hub.src.models.retinanet.non_max_supression import NonMaxSuppression from keras_hub.src.models.retinanet.prediction_head import PredictionHead from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone from keras_hub.src.models.retinanet.retinanet_label_encoder import ( From c677a70b7bc5441a658bbba080fc6f2fd12235e8 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Wed, 8 Jan 2025 13:44:45 -0800 Subject: [PATCH 32/33] Remove bounding box utils and depedencies of the utils --- keras_hub/api/bounding_box/__init__.py | 23 - keras_hub/src/bounding_box/__init__.py | 2 - keras_hub/src/bounding_box/converters.py | 606 ------------------ keras_hub/src/bounding_box/converters_test.py | 351 ---------- keras_hub/src/bounding_box/formats.py | 149 ----- keras_hub/src/bounding_box/iou.py | 251 -------- keras_hub/src/bounding_box/iou_test.py | 148 ----- keras_hub/src/bounding_box/to_dense.py | 81 --- keras_hub/src/bounding_box/to_dense_test.py | 23 - keras_hub/src/bounding_box/to_ragged.py | 86 --- keras_hub/src/bounding_box/to_ragged_test.py | 87 --- keras_hub/src/bounding_box/utils.py | 181 ------ keras_hub/src/bounding_box/utils_test.py | 155 ----- keras_hub/src/bounding_box/validate_format.py | 85 --- .../src/bounding_box/validate_format_test.py | 34 - .../src/layers/modeling/non_max_supression.py | 6 +- .../image_object_detector_preprocessor.py | 2 +- .../retinanet/retinanet_image_converter.py | 39 +- .../retinanet/retinanet_label_encoder.py | 25 +- .../retinanet/retinanet_label_encoder_test.py | 2 +- .../retinanet/retinanet_object_detector.py | 40 +- .../retinanet_object_detector_test.py | 10 +- 22 files changed, 70 insertions(+), 2316 deletions(-) delete mode 100644 keras_hub/api/bounding_box/__init__.py delete mode 100644 keras_hub/src/bounding_box/__init__.py delete mode 100644 keras_hub/src/bounding_box/converters.py delete mode 100644 keras_hub/src/bounding_box/converters_test.py delete mode 100644 
keras_hub/src/bounding_box/formats.py delete mode 100644 keras_hub/src/bounding_box/iou.py delete mode 100644 keras_hub/src/bounding_box/iou_test.py delete mode 100644 keras_hub/src/bounding_box/to_dense.py delete mode 100644 keras_hub/src/bounding_box/to_dense_test.py delete mode 100644 keras_hub/src/bounding_box/to_ragged.py delete mode 100644 keras_hub/src/bounding_box/to_ragged_test.py delete mode 100644 keras_hub/src/bounding_box/utils.py delete mode 100644 keras_hub/src/bounding_box/utils_test.py delete mode 100644 keras_hub/src/bounding_box/validate_format.py delete mode 100644 keras_hub/src/bounding_box/validate_format_test.py diff --git a/keras_hub/api/bounding_box/__init__.py b/keras_hub/api/bounding_box/__init__.py deleted file mode 100644 index dfdea4305c..0000000000 --- a/keras_hub/api/bounding_box/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -"""DO NOT EDIT. - -This file was autogenerated. Do not edit it by hand, -since your modifications would be overwritten. -""" - -from keras_hub.src.bounding_box.converters import convert_format -from keras_hub.src.bounding_box.formats import CENTER_XYWH -from keras_hub.src.bounding_box.formats import REL_XYWH -from keras_hub.src.bounding_box.formats import REL_XYXY -from keras_hub.src.bounding_box.formats import REL_YXYX -from keras_hub.src.bounding_box.formats import XYWH -from keras_hub.src.bounding_box.formats import XYXY -from keras_hub.src.bounding_box.formats import YXYX -from keras_hub.src.bounding_box.iou import compute_ciou -from keras_hub.src.bounding_box.iou import compute_iou -from keras_hub.src.bounding_box.to_dense import to_dense -from keras_hub.src.bounding_box.to_ragged import to_ragged -from keras_hub.src.bounding_box.utils import as_relative -from keras_hub.src.bounding_box.utils import clip_boxes -from keras_hub.src.bounding_box.utils import clip_to_image -from keras_hub.src.bounding_box.utils import is_relative -from keras_hub.src.bounding_box.validate_format import validate_format diff --git a/keras_hub/src/bounding_box/__init__.py b/keras_hub/src/bounding_box/__init__.py deleted file mode 100644 index 78f451fd0d..0000000000 --- a/keras_hub/src/bounding_box/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# TODO: Once all bounding boxes are moved to keras repostory remove the -# bounding box folder. diff --git a/keras_hub/src/bounding_box/converters.py b/keras_hub/src/bounding_box/converters.py deleted file mode 100644 index 7c347a9815..0000000000 --- a/keras_hub/src/bounding_box/converters.py +++ /dev/null @@ -1,606 +0,0 @@ -"""Converter functions for working with bounding box formats.""" - -import keras -from keras import ops - -from keras_hub.src.api_export import keras_hub_export - -try: - import tensorflow as tf -except ImportError: - tf = None - - -# Internal exception to propagate the fact images was not passed to a converter -# that needs it. -class RequiresImagesException(Exception): - pass - - -ALL_AXES = 4 - - -def encode_box_to_deltas( - anchors, - boxes, - anchor_format, - box_format, - encoding_format="center_yxhw", - variance=None, - image_shape=None, -): - """Encodes bounding boxes relative to anchors as deltas. - - This function calculates the deltas that represent the difference between - bounding boxes and provided anchors. Deltas encode the offsets and scaling - factors to apply to anchors to obtain the target boxes. - - Boxes and anchors are first converted to the specified `encoding_format` - (defaulting to `center_yxhw`) for consistent delta representation. - - Args: - anchors: `Tensors`. 
Anchor boxes with shape of `(N, 4)` where N is the - number of anchors. - boxes: `Tensors` Bounding boxes to encode. Boxes can be of shape - `(B, N, 4)` or `(N, 4)`. - anchor_format: str. The format of the input `anchors` - (e.g., "xyxy", "xywh", etc.). - box_format: str. The format of the input `boxes` - (e.g., "xyxy", "xywh", etc.). - encoding_format: str. The intermediate format to which boxes and anchors - are converted before delta calculation. Defaults to "center_yxhw". - variance: `List[float]`. A 4-element array/tensor representing variance - factors to scale the box deltas. If provided, the calculated deltas - are divided by the variance. Defaults to None. - image_shape: `Tuple[int]`. The shape of the image (height, width, 3). - When using relative bounding box format for `box_format` the - `image_shape` is used for normalization. - Returns: - Encoded box deltas. The return type matches the `encode_format`. - - Raises: - ValueError: If `variance` is not None and its length is not 4. - ValueError: If `encoding_format` is not `"center_xywh"` or - `"center_yxhw"`. - - """ - if variance is not None: - variance = ops.convert_to_tensor(variance, "float32") - var_len = variance.shape[-1] - - if var_len != 4: - raise ValueError(f"`variance` must be length 4, got {variance}") - - if encoding_format not in ["center_xywh", "center_yxhw"]: - raise ValueError( - "`encoding_format` should be one of 'center_xywh' or " - f"'center_yxhw', got {encoding_format}" - ) - - encoded_anchors = convert_format( - anchors, - source=anchor_format, - target=encoding_format, - image_shape=image_shape, - ) - boxes = convert_format( - boxes, - source=box_format, - target=encoding_format, - image_shape=image_shape, - ) - anchor_dimensions = ops.maximum( - encoded_anchors[..., 2:], keras.backend.epsilon() - ) - box_dimensions = ops.maximum(boxes[..., 2:], keras.backend.epsilon()) - # anchors be unbatched, boxes can either be batched or unbatched. - boxes_delta = ops.concatenate( - [ - (boxes[..., :2] - encoded_anchors[..., :2]) / anchor_dimensions, - ops.log(box_dimensions / anchor_dimensions), - ], - axis=-1, - ) - if variance is not None: - boxes_delta /= variance - return boxes_delta - - -def decode_deltas_to_boxes( - anchors, - boxes_delta, - anchor_format, - box_format, - encoded_format="center_yxhw", - variance=None, - image_shape=None, -): - """Converts bounding boxes from delta format to the specified `box_format`. - - This function decodes bounding box deltas relative to anchors to obtain the - final bounding box coordinates. The boxes are encoded in a specific - `encoded_format` (center_yxhw by default) during the decoding process. - This allows flexibility in how the deltas are applied to the anchors. - - Args: - anchors: Can be `Tensors` or `Dict[Tensors]` where keys are level - indices and values are corresponding anchor boxes. - The shape of the array/tensor should be `(N, 4)` where N is the - number of anchors. - boxes_delta Can be `Tensors` or `Dict[Tensors]` Bounding box deltas - must have the same type and structure as `anchors`. The - shape of the array/tensor can be `(N, 4)` or `(B, N, 4)` where N is - the number of boxes. - anchor_format: str. The format of the input `anchors`. - (e.g., `"xyxy"`, `"xywh"`, etc.) - box_format: str. The desired format for the output boxes. - (e.g., `"xyxy"`, `"xywh"`, etc.) - encoded_format: str. Raw output format from regression head. Defaults - to `"center_yxhw"`. - variance: `List[floats]`. 
A 4-element array/tensor representing - variance factors to scale the box deltas. If provided, the deltas - are multiplied by the variance before being applied to the anchors. - Defaults to None. - image_shape: The shape of the image (height, width). This is needed - if normalization to image size is required when converting between - formats. Defaults to None. - - Returns: - Decoded box coordinates. The return type matches the `box_format`. - - Raises: - ValueError: If `variance` is not None and its length is not 4. - ValueError: If `encoded_format` is not `"center_xywh"` or - `"center_yxhw"`. - - """ - if variance is not None: - variance = ops.convert_to_tensor(variance, "float32") - var_len = variance.shape[-1] - - if var_len != 4: - raise ValueError(f"`variance` must be length 4, got {variance}") - - if encoded_format not in ["center_xywh", "center_yxhw"]: - raise ValueError( - f"`encoded_format` should be 'center_xywh' or 'center_yxhw', " - f"but got '{encoded_format}'." - ) - - def decode_single_level(anchor, box_delta): - encoded_anchor = convert_format( - anchor, - source=anchor_format, - target=encoded_format, - image_shape=image_shape, - ) - if variance is not None: - box_delta = box_delta * variance - # anchors be unbatched, boxes can either be batched or unbatched. - box = ops.concatenate( - [ - box_delta[..., :2] * encoded_anchor[..., 2:] - + encoded_anchor[..., :2], - ops.exp(box_delta[..., 2:]) * encoded_anchor[..., 2:], - ], - axis=-1, - ) - box = convert_format( - box, - source=encoded_format, - target=box_format, - image_shape=image_shape, - ) - return box - - if isinstance(anchors, dict) and isinstance(boxes_delta, dict): - boxes = {} - for lvl, anchor in anchors.items(): - boxes[lvl] = decode_single_level(anchor, boxes_delta[lvl]) - return boxes - else: - return decode_single_level(anchors, boxes_delta) - - -def _center_yxhw_to_xyxy(boxes, images=None, image_shape=None): - y, x, height, width = ops.split(boxes, ALL_AXES, axis=-1) - return ops.concatenate( - [x - width / 2.0, y - height / 2.0, x + width / 2.0, y + height / 2.0], - axis=-1, - ) - - -def _center_xywh_to_xyxy(boxes, images=None, image_shape=None): - x, y, width, height = ops.split(boxes, ALL_AXES, axis=-1) - return ops.concatenate( - [x - width / 2.0, y - height / 2.0, x + width / 2.0, y + height / 2.0], - axis=-1, - ) - - -def _xywh_to_xyxy(boxes, images=None, image_shape=None): - x, y, width, height = ops.split(boxes, ALL_AXES, axis=-1) - return ops.concatenate([x, y, x + width, y + height], axis=-1) - - -def _xyxy_to_center_yxhw(boxes, images=None, image_shape=None): - left, top, right, bottom = ops.split(boxes, ALL_AXES, axis=-1) - return ops.concatenate( - [ - (top + bottom) / 2.0, - (left + right) / 2.0, - bottom - top, - right - left, - ], - axis=-1, - ) - - -def _rel_xywh_to_xyxy(boxes, images=None, image_shape=None): - image_height, image_width = _image_shape(images, image_shape, boxes) - x, y, width, height = ops.split(boxes, ALL_AXES, axis=-1) - return ops.concatenate( - [ - image_width * x, - image_height * y, - image_width * (x + width), - image_height * (y + height), - ], - axis=-1, - ) - - -def _xyxy_no_op(boxes, images=None, image_shape=None): - return boxes - - -def _xyxy_to_xywh(boxes, images=None, image_shape=None): - left, top, right, bottom = ops.split(boxes, ALL_AXES, axis=-1) - return ops.concatenate( - [left, top, right - left, bottom - top], - axis=-1, - ) - - -def _xyxy_to_rel_xywh(boxes, images=None, image_shape=None): - image_height, image_width = _image_shape(images, 
image_shape, boxes) - left, top, right, bottom = ops.split(boxes, ALL_AXES, axis=-1) - left, right = ( - left / image_width, - right / image_width, - ) - top, bottom = top / image_height, bottom / image_height - return ops.concatenate( - [left, top, right - left, bottom - top], - axis=-1, - ) - - -def _xyxy_to_center_xywh(boxes, images=None, image_shape=None): - left, top, right, bottom = ops.split(boxes, ALL_AXES, axis=-1) - return ops.concatenate( - [ - (left + right) / 2.0, - (top + bottom) / 2.0, - right - left, - bottom - top, - ], - axis=-1, - ) - - -def _rel_xyxy_to_xyxy(boxes, images=None, image_shape=None): - image_height, image_width = _image_shape(images, image_shape, boxes) - left, top, right, bottom = ops.split( - boxes, - ALL_AXES, - axis=-1, - ) - left, right = left * image_width, right * image_width - top, bottom = top * image_height, bottom * image_height - return ops.concatenate( - [left, top, right, bottom], - axis=-1, - ) - - -def _xyxy_to_rel_xyxy(boxes, images=None, image_shape=None): - image_height, image_width = _image_shape(images, image_shape, boxes) - left, top, right, bottom = ops.split( - boxes, - ALL_AXES, - axis=-1, - ) - left, right = left / image_width, right / image_width - top, bottom = top / image_height, bottom / image_height - return ops.concatenate( - [left, top, right, bottom], - axis=-1, - ) - - -def _yxyx_to_xyxy(boxes, images=None, image_shape=None): - y1, x1, y2, x2 = ops.split(boxes, ALL_AXES, axis=-1) - return ops.concatenate([x1, y1, x2, y2], axis=-1) - - -def _rel_yxyx_to_xyxy(boxes, images=None, image_shape=None): - image_height, image_width = _image_shape(images, image_shape, boxes) - top, left, bottom, right = ops.split( - boxes, - ALL_AXES, - axis=-1, - ) - left, right = left * image_width, right * image_width - top, bottom = top * image_height, bottom * image_height - return ops.concatenate( - [left, top, right, bottom], - axis=-1, - ) - - -def _xyxy_to_yxyx(boxes, images=None, image_shape=None): - x1, y1, x2, y2 = ops.split(boxes, ALL_AXES, axis=-1) - return ops.concatenate([y1, x1, y2, x2], axis=-1) - - -def _xyxy_to_rel_yxyx(boxes, images=None, image_shape=None): - image_height, image_width = _image_shape(images, image_shape, boxes) - left, top, right, bottom = ops.split(boxes, ALL_AXES, axis=-1) - left, right = left / image_width, right / image_width - top, bottom = top / image_height, bottom / image_height - return ops.concatenate( - [top, left, bottom, right], - axis=-1, - ) - - -TO_XYXY_CONVERTERS = { - "xywh": _xywh_to_xyxy, - "center_xywh": _center_xywh_to_xyxy, - "center_yxhw": _center_yxhw_to_xyxy, - "rel_xywh": _rel_xywh_to_xyxy, - "xyxy": _xyxy_no_op, - "rel_xyxy": _rel_xyxy_to_xyxy, - "yxyx": _yxyx_to_xyxy, - "rel_yxyx": _rel_yxyx_to_xyxy, -} - -FROM_XYXY_CONVERTERS = { - "xywh": _xyxy_to_xywh, - "center_xywh": _xyxy_to_center_xywh, - "center_yxhw": _xyxy_to_center_yxhw, - "rel_xywh": _xyxy_to_rel_xywh, - "xyxy": _xyxy_no_op, - "rel_xyxy": _xyxy_to_rel_xyxy, - "yxyx": _xyxy_to_yxyx, - "rel_yxyx": _xyxy_to_rel_yxyx, -} - - -@keras_hub_export("keras_hub.bounding_box.convert_format") -def convert_format( - boxes, source, target, images=None, image_shape=None, dtype="float32" -): - f"""Converts bounding_boxes from one format to another. - - Supported formats are: - - `"xyxy"`, also known as `corners` format. In this format the first four - axes represent `[left, top, right, bottom]` in that order. - - `"rel_xyxy"`. 
In this format, the axes are the same as `"xyxy"` but the x - coordinates are normalized using the image width, and the y axes the - image height. All values in `rel_xyxy` are in the range `(0, 1)`. - - `"xywh"`. In this format the first four axes represent - `[left, top, width, height]`. - - `"rel_xywh". In this format the first four axes represent - [left, top, width, height], just like `"xywh"`. Unlike `"xywh"`, the - values are in the range (0, 1) instead of absolute pixel values. - - `"center_xyWH"`. In this format the first two coordinates represent the x - and y coordinates of the center of the bounding box, while the last two - represent the width and height of the bounding box. - - `"center_yxHW"`. In this format the first two coordinates represent the y - and x coordinates of the center of the bounding box, while the last two - represent the height and width of the bounding box. - - `"yxyx"`. In this format the first four axes represent - [top, left, bottom, right] in that order. - - `"rel_yxyx"`. In this format, the axes are the same as `"yxyx"` but the x - coordinates are normalized using the image width, and the y axes the - image height. All values in `rel_yxyx` are in the range (0, 1). - Formats are case insensitive. It is recommended that you capitalize width - and height to maximize the visual difference between `"xyWH"` and `"xyxy"`. - - Relative formats, abbreviated `rel`, make use of the shapes of the `images` - passed. In these formats, the coordinates, widths, and heights are all - specified as percentages of the host image. `images` may be a ragged - Tensor. Note that using a ragged Tensor for images may cause a substantial - performance loss, as each image will need to be processed separately due to - the mismatching image shapes. - - Example: - - ```python - boxes = load_coco_dataset() - boxes_in_xywh = keras_hub.bounding_box.convert_format( - boxes, - source='xyxy', - target='xyWH' - ) - ``` - - Args: - boxes: tensor representing bounding boxes in the format specified in - the `source` parameter. `boxes` can optionally have extra - dimensions stacked on the final axis to store metadata. boxes - should be a 3D tensor, with the shape `[batch_size, num_boxes, 4]`. - Alternatively, boxes can be a dictionary with key 'boxes' containing - a tensor matching the aforementioned spec. - source:One of {" ".join([f'"{f}"' for f in TO_XYXY_CONVERTERS.keys()])}. - Used to specify the original format of the `boxes` parameter. - target:One of {" ".join([f'"{f}"' for f in TO_XYXY_CONVERTERS.keys()])}. - Used to specify the destination format of the `boxes` parameter. - images: (Optional) a batch of images aligned with `boxes` on the first - axis. Should be at least 3 dimensions, with the first 3 dimensions - representing: `[batch_size, height, width]`. Used in some - converters to compute relative pixel values of the bounding box - dimensions. Required when transforming from a rel format to a - non-rel format. - dtype: the data type to use when transforming the boxes, defaults to - `"float32"`. - """ - if isinstance(boxes, dict): - converted_boxes = boxes.copy() - converted_boxes["boxes"] = convert_format( - boxes["boxes"], - source=source, - target=target, - images=images, - image_shape=image_shape, - dtype=dtype, - ) - return converted_boxes - - if boxes.shape[-1] is not None and boxes.shape[-1] != 4: - raise ValueError( - "Expected `boxes` to be a Tensor with a final dimension of " - f"`4`. Instead, got `boxes.shape={boxes.shape}`." 
- ) - if images is not None and image_shape is not None: - raise ValueError( - "convert_format() expects either `images` or `image_shape`, but " - f"not both. Received images={images} image_shape={image_shape}" - ) - - _validate_image_shape(image_shape) - - source = source.lower() - target = target.lower() - if source not in TO_XYXY_CONVERTERS: - raise ValueError( - "`convert_format()` received an unsupported format for the " - "argument `source`. `source` should be one of " - f"{TO_XYXY_CONVERTERS.keys()}. Got source={source}" - ) - if target not in FROM_XYXY_CONVERTERS: - raise ValueError( - "`convert_format()` received an unsupported format for the " - "argument `target`. `target` should be one of " - f"{FROM_XYXY_CONVERTERS.keys()}. Got target={target}" - ) - - boxes = ops.cast(boxes, dtype) - if source == target: - return boxes - - # rel->rel conversions should not require images - if source.startswith("rel") and target.startswith("rel"): - source = source.replace("rel_", "", 1) - target = target.replace("rel_", "", 1) - - boxes, images, squeeze = _format_inputs(boxes, images) - to_xyxy_fn = TO_XYXY_CONVERTERS[source] - from_xyxy_fn = FROM_XYXY_CONVERTERS[target] - - try: - in_xyxy = to_xyxy_fn(boxes, images=images, image_shape=image_shape) - result = from_xyxy_fn(in_xyxy, images=images, image_shape=image_shape) - except RequiresImagesException: - raise ValueError( - "convert_format() must receive `images` or `image_shape` when " - "transforming between relative and absolute formats." - f"convert_format() received source=`{format}`, target=`{format}, " - f"but images={images} and image_shape={image_shape}." - ) - - return _format_outputs(result, squeeze) - - -def _format_inputs(boxes, images): - boxes_rank = len(boxes.shape) - if boxes_rank > 3: - raise ValueError( - "Expected len(boxes.shape)=2, or len(boxes.shape)=3, got " - f"len(boxes.shape)={boxes_rank}" - ) - boxes_includes_batch = boxes_rank == 3 - # Determine if images needs an expand_dims() call - if images is not None: - images_rank = len(images.shape) - if images_rank > 4: - raise ValueError( - "Expected len(images.shape)=2, or len(images.shape)=3, got " - f"len(images.shape)={images_rank}" - ) - images_include_batch = images_rank == 4 - if boxes_includes_batch != images_include_batch: - raise ValueError( - "convert_format() expects both boxes and images to be batched, " - "or both boxes and images to be unbatched. Received " - f"len(boxes.shape)={boxes_rank}, " - f"len(images.shape)={images_rank}. Expected either " - "len(boxes.shape)=2 AND len(images.shape)=3, or " - "len(boxes.shape)=3 AND len(images.shape)=4." - ) - if not images_include_batch: - images = ops.expand_dims(images, axis=0) - - if not boxes_includes_batch: - return ops.expand_dims(boxes, axis=0), images, True - return boxes, images, False - - -def _validate_image_shape(image_shape): - # Escape early if image_shape is None and skip validation. 
- if image_shape is None: - return - # tuple/list - if isinstance(image_shape, (tuple, list)): - if len(image_shape) != 3: - raise ValueError( - "image_shape should be of length 3, but got " - f"image_shape={image_shape}" - ) - return - - # tensor - if ops.is_tensor(image_shape): - if len(image_shape.shape) > 1: - raise ValueError( - "image_shape.shape should be (3), but got " - f"image_shape.shape={image_shape.shape}" - ) - if image_shape.shape[0] != 3: - raise ValueError( - "image_shape.shape should be (3), but got " - f"image_shape.shape={image_shape.shape}" - ) - return - - # Warn about failure cases - raise ValueError( - "Expected image_shape to be either a tuple, list, Tensor. " - f"Received image_shape={image_shape}" - ) - - -def _format_outputs(boxes, squeeze): - if squeeze: - return ops.squeeze(boxes, axis=0) - return boxes - - -def _image_shape(images, image_shape, boxes): - if images is None and image_shape is None: - raise RequiresImagesException() - - if image_shape is None: - if not isinstance(images, tf.RaggedTensor): - image_shape = ops.shape(images) - height, width = image_shape[1], image_shape[2] - else: - height = ops.reshape(images.row_lengths(), (-1, 1)) - width = ops.reshape(ops.max(images.row_lengths(axis=2), 1), (-1, 1)) - height = ops.expand_dims(height, axis=-1) - width = ops.expand_dims(width, axis=-1) - else: - height, width = image_shape[0], image_shape[1] - return ops.cast(height, boxes.dtype), ops.cast(width, boxes.dtype) diff --git a/keras_hub/src/bounding_box/converters_test.py b/keras_hub/src/bounding_box/converters_test.py deleted file mode 100644 index 9617a2a2aa..0000000000 --- a/keras_hub/src/bounding_box/converters_test.py +++ /dev/null @@ -1,351 +0,0 @@ -import itertools - -import numpy as np -import pytest -import tensorflow as tf -from absl.testing import parameterized -from keras import backend - -from keras_hub.src.bounding_box import converters -from keras_hub.src.bounding_box import to_dense -from keras_hub.src.bounding_box import to_ragged -from keras_hub.src.tests.test_case import TestCase - - -class ConvertersTestCase(TestCase): - def setUp(self): - xyxy_box = np.array( - [[[10, 20, 110, 120], [20, 30, 120, 130]]], dtype="float32" - ) - yxyx_box = np.array( - [[[20, 10, 120, 110], [30, 20, 130, 120]]], dtype="float32" - ) - rel_xyxy_box = np.array( - [[[0.01, 0.02, 0.11, 0.12], [0.02, 0.03, 0.12, 0.13]]], - dtype="float32", - ) - rel_xyxy_box_ragged_images = np.array( - [[[0.10, 0.20, 1.1, 1.20], [0.40, 0.6, 2.40, 2.6]]], dtype="float32" - ) - rel_yxyx_box = np.array( - [[[0.02, 0.01, 0.12, 0.11], [0.03, 0.02, 0.13, 0.12]]], - dtype="float32", - ) - rel_yxyx_box_ragged_images = np.array( - [[[0.2, 0.1, 1.2, 1.1], [0.6, 0.4, 2.6, 2.4]]], dtype="float32" - ) - center_xywh_box = np.array( - [[[60, 70, 100, 100], [70, 80, 100, 100]]], dtype="float32" - ) - xywh_box = np.array( - [[[10, 20, 100, 100], [20, 30, 100, 100]]], dtype="float32" - ) - rel_xywh_box = np.array( - [[[0.01, 0.02, 0.1, 0.1], [0.02, 0.03, 0.1, 0.1]]], dtype="float32" - ) - rel_xywh_box_ragged_images = np.array( - [[[0.1, 0.2, 1, 1], [0.4, 0.6, 2, 2]]], dtype="float32" - ) - - self.ragged_images = tf.ragged.constant( - [ - np.ones(shape=[100, 100, 3]), - np.ones(shape=[50, 50, 3]), - ], # 2 images - ragged_rank=2, - ) - - self.images = np.ones([2, 1000, 1000, 3]) - - self.ragged_classes = tf.ragged.constant([[0], [0]], dtype="float32") - - self.boxes = { - "xyxy": xyxy_box, - "center_xywh": center_xywh_box, - "rel_xywh": rel_xywh_box, - "xywh": xywh_box, - "rel_xyxy": 
rel_xyxy_box, - "yxyx": yxyx_box, - "rel_yxyx": rel_yxyx_box, - } - - self.boxes_ragged_images = { - "xyxy": xyxy_box, - "center_xywh": center_xywh_box, - "rel_xywh": rel_xywh_box_ragged_images, - "xywh": xywh_box, - "rel_xyxy": rel_xyxy_box_ragged_images, - "yxyx": yxyx_box, - "rel_yxyx": rel_yxyx_box_ragged_images, - } - - @parameterized.named_parameters( - *[ - (f"{source}_{target}", source, target) - for (source, target) in itertools.permutations( - [ - "xyxy", - "center_xywh", - "rel_xywh", - "xywh", - "rel_xyxy", - "yxyx", - "rel_yxyx", - ], - 2, - ) - ] - + [("xyxy_xyxy", "xyxy", "xyxy")] - ) - def test_converters(self, source, target): - source, target - source_box = self.boxes[source] - target_box = self.boxes[target] - - self.assertAllClose( - converters.convert_format( - source_box, source=source, target=target, images=self.images - ), - target_box, - ) - - @parameterized.named_parameters( - *[ - (f"{source}_{target}", source, target) - for (source, target) in itertools.permutations( - [ - "xyxy", - "center_xywh", - "rel_xywh", - "xywh", - "rel_xyxy", - "yxyx", - "rel_yxyx", - ], - 2, - ) - ] - + [("xyxy_xyxy", "xyxy", "xyxy")] - ) - @pytest.mark.skipif( - backend.backend() != "tensorflow", - reason="Only applies to backends which support raggeds", - ) - def test_converters_ragged_images(self, source, target): - source_box = _raggify(self.boxes_ragged_images[source]) - target_box = _raggify(self.boxes_ragged_images[target]) - self.assertAllClose( - converters.convert_format( - source_box, - source=source, - target=target, - images=self.ragged_images, - ), - target_box, - ) - - @parameterized.named_parameters( - *[ - (f"{source}_{target}", source, target) - for (source, target) in itertools.permutations( - [ - "xyxy", - "center_xywh", - "rel_xywh", - "xywh", - "rel_xyxy", - "yxyx", - "rel_yxyx", - ], - 2, - ) - ] - + [("xyxy_xyxy", "xyxy", "xyxy")] - ) - def test_converters_unbatched(self, source, target): - source_box = self.boxes[source][0] - target_box = self.boxes[target][0] - - self.assertAllClose( - converters.convert_format( - source_box, source=source, target=target, images=self.images[0] - ), - target_box, - ) - - def test_raises_with_different_image_rank(self): - source_box = self.boxes["xyxy"][0] - with self.assertRaises(ValueError): - converters.convert_format( - source_box, source="xyxy", target="xywh", images=self.images - ) - - def test_without_images(self): - source_box = self.boxes["xyxy"] - target_box = self.boxes["xywh"] - self.assertAllClose( - converters.convert_format(source_box, source="xyxy", target="xywh"), - target_box, - ) - - def test_rel_to_rel_without_images(self): - source_box = self.boxes["rel_xyxy"] - target_box = self.boxes["rel_yxyx"] - self.assertAllClose( - converters.convert_format( - source_box, source="rel_xyxy", target="rel_yxyx" - ), - target_box, - ) - - @parameterized.named_parameters( - *[ - (f"{source}_{target}", source, target) - for (source, target) in itertools.permutations( - [ - "xyxy", - "center_xywh", - "rel_xywh", - "xywh", - "rel_xyxy", - "yxyx", - "rel_yxyx", - ], - 2, - ) - ] - + [("xyxy_xyxy", "xyxy", "xyxy")] - ) - @pytest.mark.skipif( - backend.backend() != "tensorflow", - reason="Only applies to backends which support raggeds", - ) - def test_ragged_bounding_box(self, source, target): - source_box = _raggify(self.boxes[source]) - target_box = _raggify(self.boxes[target]) - self.assertAllClose( - converters.convert_format( - source_box, source=source, target=target, images=self.images - ), - target_box, - ) - - 
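For reference, the API exercised by these tests is `keras_hub.bounding_box.convert_format`, which this patch removes. A minimal usage sketch of the pre-removal API, reusing the fixture values from the tests above (illustrative only, not part of the patch):

```python
import numpy as np

import keras_hub

# Two boxes in "xyxy" (corners) format: [left, top, right, bottom].
boxes = np.array([[[10, 20, 110, 120], [20, 30, 120, 130]]], dtype="float32")

# Convert to "xywh": [left, top, width, height]. Absolute-to-absolute
# conversions need no image; relative targets such as "rel_xyxy" would also
# require `images=` or `image_shape=`.
boxes_xywh = keras_hub.bounding_box.convert_format(
    boxes, source="xyxy", target="xywh"
)
# Expected result: [[[10, 20, 100, 100], [20, 30, 100, 100]]]
```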
@parameterized.named_parameters( - *[ - (f"{source}_{target}", source, target) - for (source, target) in itertools.permutations( - [ - "xyxy", - "center_xywh", - "rel_xywh", - "xywh", - "rel_xyxy", - "yxyx", - "rel_yxyx", - ], - 2, - ) - ] - + [("xyxy_xyxy", "xyxy", "xyxy")] - ) - @pytest.mark.skipif( - backend.backend() != "tensorflow", - reason="Only applies to backends which support raggeds", - ) - def test_ragged_bounding_box_ragged_images(self, source, target): - source_box = _raggify(self.boxes_ragged_images[source]) - target_box = _raggify(self.boxes_ragged_images[target]) - self.assertAllClose( - converters.convert_format( - source_box, - source=source, - target=target, - images=self.ragged_images, - ), - target_box, - ) - - @parameterized.named_parameters( - *[ - (f"{source}_{target}", source, target) - for (source, target) in itertools.permutations( - [ - "xyxy", - "center_xywh", - "rel_xywh", - "xywh", - "rel_xyxy", - "yxyx", - "rel_yxyx", - ], - 2, - ) - ] - + [("xyxy_xyxy", "xyxy", "xyxy")] - ) - @pytest.mark.skipif( - backend.backend() != "tensorflow", - reason="Only applies to backends which support raggeds", - ) - def test_ragged_bounding_box_with_image_shape(self, source, target): - source_box = _raggify(self.boxes[source]) - target_box = _raggify(self.boxes[target]) - self.assertAllClose( - converters.convert_format( - source_box, - source=source, - target=target, - image_shape=(1000, 1000, 3), - ), - target_box, - ) - - @parameterized.named_parameters( - *[ - (f"{source}_{target}", source, target) - for (source, target) in itertools.permutations( - [ - "xyxy", - "center_xywh", - "rel_xywh", - "xywh", - "rel_xyxy", - "yxyx", - "rel_yxyx", - ], - 2, - ) - ] - + [("xyxy_xyxy", "xyxy", "xyxy")] - ) - @pytest.mark.skipif( - backend.backend() != "tensorflow", - reason="Only applies to backends which support raggeds", - ) - def test_dense_bounding_box_with_ragged_images(self, source, target): - source_box = _raggify(self.boxes_ragged_images[source]) - target_box = _raggify(self.boxes_ragged_images[target]) - source_bounding_boxes = { - "boxes": source_box, - "classes": self.ragged_classes, - } - source_bounding_boxes = to_dense.to_dense(source_bounding_boxes) - - result_bounding_boxes = converters.convert_format( - source_bounding_boxes, - source=source, - target=target, - images=self.ragged_images, - ) - result_bounding_boxes = to_ragged.to_ragged(result_bounding_boxes) - - self.assertAllClose( - result_bounding_boxes["boxes"], - target_box, - ) - - -def _raggify(tensor): - tensor = tf.squeeze(tensor, axis=0) - tensor = tf.RaggedTensor.from_row_lengths(tensor, [1, 1]) - return tensor diff --git a/keras_hub/src/bounding_box/formats.py b/keras_hub/src/bounding_box/formats.py deleted file mode 100644 index c8e50ab60a..0000000000 --- a/keras_hub/src/bounding_box/formats.py +++ /dev/null @@ -1,149 +0,0 @@ -""" -formats.py contains axis information for each supported format. -""" - -from keras_hub.src.api_export import keras_hub_export - - -@keras_hub_export("keras_hub.bounding_box.XYXY") -class XYXY: - """XYXY contains axis indices for the XYXY format. - - All values in the XYXY format should be absolute pixel values. 
- - The XYXY format consists of the following required indices: - - - LEFT: left of the bounding box - - TOP: top of the bounding box - - RIGHT: right of the bounding box - - BOTTOM: bottom of the bounding box - """ - - LEFT = 0 - TOP = 1 - RIGHT = 2 - BOTTOM = 3 - - -@keras_hub_export("keras_hub.bounding_box.REL_XYXY") -class REL_XYXY: - """REL_XYXY contains axis indices for the REL_XYXY format. - - REL_XYXY is like XYXY, but each value is relative to the width and height of - the origin image. Values are percentages of the origin images' width and - height respectively. - - The REL_XYXY format consists of the following required indices: - - - LEFT: left of the bounding box - - TOP: top of the bounding box - - RIGHT: right of the bounding box - - BOTTOM: bottom of the bounding box - """ - - LEFT = 0 - TOP = 1 - RIGHT = 2 - BOTTOM = 3 - - -@keras_hub_export("keras_hub.bounding_box.CENTER_XYWH") -class CENTER_XYWH: - """CENTER_XYWH contains axis indices for the CENTER_XYWH format. - - All values in the CENTER_XYWH format should be absolute pixel values. - - The CENTER_XYWH format consists of the following required indices: - - - X: X coordinate of the center of the bounding box - - Y: Y coordinate of the center of the bounding box - - WIDTH: width of the bounding box - - HEIGHT: height of the bounding box - """ - - X = 0 - Y = 1 - WIDTH = 2 - HEIGHT = 3 - - -@keras_hub_export("keras_hub.bounding_box.XYWH") -class XYWH: - """XYWH contains axis indices for the XYWH format. - - All values in the XYWH format should be absolute pixel values. - - The XYWH format consists of the following required indices: - - - X: X coordinate of the left of the bounding box - - Y: Y coordinate of the top of the bounding box - - WIDTH: width of the bounding box - - HEIGHT: height of the bounding box - """ - - X = 0 - Y = 1 - WIDTH = 2 - HEIGHT = 3 - - -@keras_hub_export("keras_hub.bounding_box.REL_XYWH") -class REL_XYWH: - """REL_XYWH contains axis indices for the XYWH format. - - REL_XYXY is like XYWH, but each value is relative to the width and height of - the origin image. Values are percentages of the origin images' width and - height respectively. - - - X: X coordinate of the left of the bounding box - - Y: Y coordinate of the top of the bounding box - - WIDTH: width of the bounding box - - HEIGHT: height of the bounding box - """ - - X = 0 - Y = 1 - WIDTH = 2 - HEIGHT = 3 - - -@keras_hub_export("keras_hub.bounding_box.YXYX") -class YXYX: - """YXYX contains axis indices for the YXYX format. - - All values in the YXYX format should be absolute pixel values. - - The YXYX format consists of the following required indices: - - - TOP: top of the bounding box - - LEFT: left of the bounding box - - BOTTOM: bottom of the bounding box - - RIGHT: right of the bounding box - """ - - TOP = 0 - LEFT = 1 - BOTTOM = 2 - RIGHT = 3 - - -@keras_hub_export("keras_hub.bounding_box.REL_YXYX") -class REL_YXYX: - """REL_YXYX contains axis indices for the REL_YXYX format. - - REL_YXYX is like YXYX, but each value is relative to the width and height of - the origin image. Values are percentages of the origin images' width and - height respectively. 
- - The REL_YXYX format consists of the following required indices: - - - TOP: top of the bounding box - - LEFT: left of the bounding box - - BOTTOM: bottom of the bounding box - - RIGHT: right of the bounding box - """ - - TOP = 0 - LEFT = 1 - BOTTOM = 2 - RIGHT = 3 diff --git a/keras_hub/src/bounding_box/iou.py b/keras_hub/src/bounding_box/iou.py deleted file mode 100644 index df2c907e4a..0000000000 --- a/keras_hub/src/bounding_box/iou.py +++ /dev/null @@ -1,251 +0,0 @@ -"""Contains functions to compute ious of bounding boxes.""" - -import math - -import keras -from keras import ops - -from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.bounding_box.converters import convert_format -from keras_hub.src.bounding_box.utils import as_relative -from keras_hub.src.bounding_box.utils import is_relative - - -def _compute_area(box): - """Computes area for bounding boxes - - Args: - box: [N, 4] or [batch_size, N, 4] float Tensor, either batched - or unbatched boxes. - Returns: - a float Tensor of [N] or [batch_size, N] - """ - y_min, x_min, y_max, x_max = ops.split(box[..., :4], 4, axis=-1) - return ops.squeeze((y_max - y_min) * (x_max - x_min), axis=-1) - - -def _compute_intersection(boxes1, boxes2): - """Computes intersection area between two sets of boxes. - - Args: - boxes1: [N, 4] or [batch_size, N, 4] float Tensor boxes. - boxes2: [M, 4] or [batch_size, M, 4] float Tensor boxes. - Returns: - a [N, M] or [batch_size, N, M] float Tensor. - """ - y_min1, x_min1, y_max1, x_max1 = ops.split(boxes1[..., :4], 4, axis=-1) - y_min2, x_min2, y_max2, x_max2 = ops.split(boxes2[..., :4], 4, axis=-1) - boxes2_rank = len(boxes2.shape) - perm = [1, 0] if boxes2_rank == 2 else [0, 2, 1] - # [N, M] or [batch_size, N, M] - intersect_ymax = ops.minimum(y_max1, ops.transpose(y_max2, perm)) - intersect_ymin = ops.maximum(y_min1, ops.transpose(y_min2, perm)) - intersect_xmax = ops.minimum(x_max1, ops.transpose(x_max2, perm)) - intersect_xmin = ops.maximum(x_min1, ops.transpose(x_min2, perm)) - - intersect_height = intersect_ymax - intersect_ymin - intersect_width = intersect_xmax - intersect_xmin - zeros_t = ops.cast(0, intersect_height.dtype) - intersect_height = ops.maximum(zeros_t, intersect_height) - intersect_width = ops.maximum(zeros_t, intersect_width) - - return intersect_height * intersect_width - - -@keras_hub_export("keras_hub.bounding_box.compute_iou") -def compute_iou( - boxes1, - boxes2, - bounding_box_format, - use_masking=False, - mask_val=-1, - images=None, - image_shape=None, -): - """Computes a lookup table vector containing the ious for a given set boxes. - - The lookup vector is to be indexed by [`boxes1_index`,`boxes2_index`] if - boxes are unbatched and by [`batch`, `boxes1_index`,`boxes2_index`] if the - boxes are batched. - - The users can pass `boxes1` and `boxes2` to be different ranks. For example: - 1) `boxes1`: [batch_size, M, 4], `boxes2`: [batch_size, N, 4] -> return - [batch_size, M, N]. - 2) `boxes1`: [batch_size, M, 4], `boxes2`: [N, 4] -> return - [batch_size, M, N] - 3) `boxes1`: [M, 4], `boxes2`: [batch_size, N, 4] -> return - [batch_size, M, N] - 4) `boxes1`: [M, 4], `boxes2`: [N, 4] -> return [M, N] - - Args: - boxes1: a list of bounding boxes in 'corners' format. Can be batched or - unbatched. - boxes2: a list of bounding boxes in 'corners' format. Can be batched or - unbatched. - bounding_box_format: a case-insensitive string which is one of `"xyxy"`, - `"rel_xyxy"`, `"xyWH"`, `"center_xyWH"`, `"yxyx"`, `"rel_yxyx"`. 
- For detailed information on the supported format, see the - [KerasCV bounding box documentation](https://keras.io/api/keras_cv/bounding_box/formats/). - use_masking: whether masking will be applied. This will mask all `boxes1` - or `boxes2` that have values less than 0 in all its 4 dimensions. - Default to `False`. - mask_val: int to mask those returned IOUs if the masking is True, defaults - to -1. - - Returns: - iou_lookup_table: a vector containing the pairwise ious of boxes1 and - boxes2. - """ # noqa: E501 - - boxes1_rank = len(boxes1.shape) - boxes2_rank = len(boxes2.shape) - - if boxes1_rank not in [2, 3]: - raise ValueError( - "compute_iou() expects boxes1 to be batched, or to be unbatched. " - f"Received len(boxes1.shape)={boxes1_rank}, " - f"len(boxes2.shape)={boxes2_rank}. Expected either " - "len(boxes1.shape)=2 AND or len(boxes1.shape)=3." - ) - if boxes2_rank not in [2, 3]: - raise ValueError( - "compute_iou() expects boxes2 to be batched, or to be unbatched. " - f"Received len(boxes1.shape)={boxes1_rank}, " - f"len(boxes2.shape)={boxes2_rank}. Expected either " - "len(boxes2.shape)=2 AND or len(boxes2.shape)=3." - ) - - target_format = "yxyx" - if is_relative(bounding_box_format): - target_format = as_relative(target_format) - - boxes1 = convert_format( - boxes1, - source=bounding_box_format, - target=target_format, - images=images, - image_shape=image_shape, - ) - - boxes2 = convert_format( - boxes2, - source=bounding_box_format, - target=target_format, - images=images, - image_shape=image_shape, - ) - - intersect_area = _compute_intersection(boxes1, boxes2) - boxes1_area = _compute_area(boxes1) - boxes2_area = _compute_area(boxes2) - boxes2_area_rank = len(boxes2_area.shape) - boxes2_axis = 1 if (boxes2_area_rank == 2) else 0 - boxes1_area = ops.expand_dims(boxes1_area, axis=-1) - boxes2_area = ops.expand_dims(boxes2_area, axis=boxes2_axis) - union_area = boxes1_area + boxes2_area - intersect_area - res = ops.divide(intersect_area, union_area + keras.backend.epsilon()) - - if boxes1_rank == 2: - perm = [1, 0] - else: - perm = [0, 2, 1] - - if not use_masking: - return res - - mask_val_t = ops.cast(mask_val, res.dtype) * ops.ones_like(res) - boxes1_mask = ops.less(ops.max(boxes1, axis=-1, keepdims=True), 0.0) - boxes2_mask = ops.less(ops.max(boxes2, axis=-1, keepdims=True), 0.0) - background_mask = ops.logical_or( - boxes1_mask, ops.transpose(boxes2_mask, perm) - ) - iou_lookup_table = ops.where(background_mask, mask_val_t, res) - return iou_lookup_table - - -@keras_hub_export("keras_hub.bounding_box.compute_ciou") -def compute_ciou(boxes1, boxes2, bounding_box_format): - """ - Computes the Complete IoU (CIoU) between two bounding boxes or between - two batches of bounding boxes. - - CIoU loss is an extension of GIoU loss, which further improves the IoU - optimization for object detection. CIoU loss not only penalizes the - bounding box coordinates but also considers the aspect ratio and center - distance of the boxes. The length of the last dimension should be 4 to - represent the bounding boxes. - - Args: - box1 (tensor): tensor representing the first bounding box with - shape (..., 4). - box2 (tensor): tensor representing the second bounding box with - shape (..., 4). - bounding_box_format: a case-insensitive string (for example, "xyxy"). - Each bounding box is defined by these 4 values. For detailed - information on the supported formats, see the [KerasCV bounding box - documentation](https://keras.io/api/keras_cv/bounding_box/formats/). 
- - Returns: - tensor: The CIoU distance between the two bounding boxes. - """ - target_format = "xyxy" - if is_relative(bounding_box_format): - target_format = as_relative(target_format) - - boxes1 = convert_format( - boxes1, source=bounding_box_format, target=target_format - ) - - boxes2 = convert_format( - boxes2, source=bounding_box_format, target=target_format - ) - - x_min1, y_min1, x_max1, y_max1 = ops.split(boxes1[..., :4], 4, axis=-1) - x_min2, y_min2, x_max2, y_max2 = ops.split(boxes2[..., :4], 4, axis=-1) - - width_1 = x_max1 - x_min1 - height_1 = y_max1 - y_min1 + keras.backend.epsilon() - width_2 = x_max2 - x_min2 - height_2 = y_max2 - y_min2 + keras.backend.epsilon() - - intersection_area = ops.maximum( - ops.minimum(x_max1, x_max2) - ops.maximum(x_min1, x_min2), 0 - ) * ops.maximum( - ops.minimum(y_max1, y_max2) - ops.maximum(y_min1, y_min2), 0 - ) - union_area = ( - width_1 * height_1 - + width_2 * height_2 - - intersection_area - + keras.backend.epsilon() - ) - iou = ops.squeeze( - ops.divide(intersection_area, union_area + keras.backend.epsilon()), - axis=-1, - ) - - convex_width = ops.maximum(x_max1, x_max2) - ops.minimum(x_min1, x_min2) - convex_height = ops.maximum(y_max1, y_max2) - ops.minimum(y_min1, y_min2) - convex_diagonal_squared = ops.squeeze( - convex_width**2 + convex_height**2 + keras.backend.epsilon(), - axis=-1, - ) - centers_distance_squared = ops.squeeze( - ((x_min1 + x_max1) / 2 - (x_min2 + x_max2) / 2) ** 2 - + ((y_min1 + y_max1) / 2 - (y_min2 + y_max2) / 2) ** 2, - axis=-1, - ) - - v = ops.squeeze( - ops.power( - (4 / math.pi**2) - * (ops.arctan(width_2 / height_2) - ops.arctan(width_1 / height_1)), - 2, - ), - axis=-1, - ) - alpha = v / (v - iou + (1 + keras.backend.epsilon())) - - return iou - ( - centers_distance_squared / convex_diagonal_squared + v * alpha - ) diff --git a/keras_hub/src/bounding_box/iou_test.py b/keras_hub/src/bounding_box/iou_test.py deleted file mode 100644 index 2e00f24869..0000000000 --- a/keras_hub/src/bounding_box/iou_test.py +++ /dev/null @@ -1,148 +0,0 @@ -"""Tests for iou functions.""" - -import numpy as np - -from keras_hub.src.bounding_box import iou as iou_lib -from keras_hub.src.tests.test_case import TestCase - - -class IoUTest(TestCase): - def test_compute_single_iou(self): - bb1 = np.array([[100, 101, 200, 201]]) - bb1_off_by_1 = np.array([[101, 102, 201, 202]]) - # area of bb1 and bb1_off_by_1 are each 10000. 
- # intersection area is 99*99=9801 - # iou=9801/(2*10000 - 9801)=0.96097656633 - self.assertAllClose( - iou_lib.compute_iou(bb1, bb1_off_by_1, "yxyx")[0], [0.96097656633] - ) - - def test_compute_iou(self): - bb1 = [100, 101, 200, 201] - bb1_off_by_1_pred = [101, 102, 201, 202] - iou_bb1_bb1_off = 0.96097656633 - top_left_bounding_box = [0, 2, 1, 3] - far_away_box = [1300, 1400, 1500, 1401] - another_far_away_pred = [1000, 1400, 1200, 1401] - - # Rows represent predictions, columns ground truths - expected_result = np.array( - [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], - dtype=np.float32, - ) - - sample_y_true = np.array([bb1, top_left_bounding_box, far_away_box]) - sample_y_pred = np.array( - [bb1_off_by_1_pred, top_left_bounding_box, another_far_away_pred], - ) - - result = iou_lib.compute_iou(sample_y_true, sample_y_pred, "yxyx") - self.assertAllClose(expected_result, result) - - def test_batched_compute_iou(self): - bb1 = [100, 101, 200, 201] - bb1_off_by_1_pred = [101, 102, 201, 202] - iou_bb1_bb1_off = 0.96097656633 - top_left_bounding_box = [0, 2, 1, 3] - far_away_box = [1300, 1400, 1500, 1401] - another_far_away_pred = [1000, 1400, 1200, 1401] - - # Rows represent predictions, columns ground truths - expected_result = np.array( - [ - [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], - [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], - ], - ) - - sample_y_true = np.array( - [ - [bb1, top_left_bounding_box, far_away_box], - [bb1, top_left_bounding_box, far_away_box], - ], - ) - sample_y_pred = np.array( - [ - [ - bb1_off_by_1_pred, - top_left_bounding_box, - another_far_away_pred, - ], - [ - bb1_off_by_1_pred, - top_left_bounding_box, - another_far_away_pred, - ], - ], - ) - - result = iou_lib.compute_iou(sample_y_true, sample_y_pred, "yxyx") - self.assertAllClose(expected_result, result) - - def test_batched_boxes1_unbatched_boxes2(self): - bb1 = [100, 101, 200, 201] - bb1_off_by_1_pred = [101, 102, 201, 202] - iou_bb1_bb1_off = 0.96097656633 - top_left_bounding_box = [0, 2, 1, 3] - far_away_box = [1300, 1400, 1500, 1401] - another_far_away_pred = [1000, 1400, 1200, 1401] - - # Rows represent predictions, columns ground truths - expected_result = np.array( - [ - [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], - [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], - ], - ) - - sample_y_true = np.array( - [ - [bb1, top_left_bounding_box, far_away_box], - [bb1, top_left_bounding_box, far_away_box], - ], - ) - sample_y_pred = np.array( - [bb1_off_by_1_pred, top_left_bounding_box, another_far_away_pred], - ) - - result = iou_lib.compute_iou(sample_y_true, sample_y_pred, "yxyx") - self.assertAllClose(expected_result, result) - - def test_unbatched_boxes1_batched_boxes2(self): - bb1 = [100, 101, 200, 201] - bb1_off_by_1_pred = [101, 102, 201, 202] - iou_bb1_bb1_off = 0.96097656633 - top_left_bounding_box = [0, 2, 1, 3] - far_away_box = [1300, 1400, 1500, 1401] - another_far_away_pred = [1000, 1400, 1200, 1401] - - # Rows represent predictions, columns ground truths - expected_result = np.array( - [ - [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], - [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], - ], - ) - - sample_y_true = np.array( - [ - [bb1, top_left_bounding_box, far_away_box], - ], - ) - sample_y_pred = np.array( - [ - [ - bb1_off_by_1_pred, - top_left_bounding_box, - another_far_away_pred, - ], - [ - bb1_off_by_1_pred, - top_left_bounding_box, - 
another_far_away_pred, - ], - ], - ) - - result = iou_lib.compute_iou(sample_y_true, sample_y_pred, "yxyx") - self.assertAllClose(expected_result, result) diff --git a/keras_hub/src/bounding_box/to_dense.py b/keras_hub/src/bounding_box/to_dense.py deleted file mode 100644 index 68b00d065f..0000000000 --- a/keras_hub/src/bounding_box/to_dense.py +++ /dev/null @@ -1,81 +0,0 @@ -import keras_hub.src.bounding_box.validate_format as validate_format -from keras_hub.src.api_export import keras_hub_export - -try: - import tensorflow as tf -except ImportError: - tf = None - - -def _box_shape(batched, boxes_shape, max_boxes): - # ensure we dont drop the final axis in RaggedTensor mode - if max_boxes is None: - shape = list(boxes_shape) - shape[-1] = 4 - return shape - if batched: - return [None, max_boxes, 4] - return [max_boxes, 4] - - -def _classes_shape(batched, classes_shape, max_boxes): - if max_boxes is None: - return None - if batched: - return [None, max_boxes] + classes_shape[2:] - return [max_boxes] + classes_shape[2:] - - -@keras_hub_export("keras_hub.bounding_box.to_dense") -def to_dense(bounding_boxes, max_boxes=None, default_value=-1): - """to_dense converts bounding boxes to Dense tensors - - Args: - bounding_boxes: bounding boxes in KerasCV dictionary format. - max_boxes: the maximum number of boxes, used to pad tensors to a given - shape. This can be used to make object detection pipelines TPU - compatible. - default_value: the default value to pad bounding boxes with. defaults - to -1. - """ - info = validate_format.validate_format(bounding_boxes) - - # guards against errors in metrics regarding modification of inputs. - # also guards against unexpected behavior when modifying downstream - bounding_boxes = bounding_boxes.copy() - - # Already running in masked mode - if not info["ragged"]: - # even if already ragged, still copy the dictionary for API consistency - return bounding_boxes - - if isinstance(bounding_boxes["classes"], tf.RaggedTensor): - bounding_boxes["classes"] = bounding_boxes["classes"].to_tensor( - default_value=default_value, - shape=_classes_shape( - info["is_batched"], bounding_boxes["classes"].shape, max_boxes - ), - ) - - if isinstance(bounding_boxes["boxes"], tf.RaggedTensor): - bounding_boxes["boxes"] = bounding_boxes["boxes"].to_tensor( - default_value=default_value, - shape=_box_shape( - info["is_batched"], bounding_boxes["boxes"].shape, max_boxes - ), - ) - - if "confidence" in bounding_boxes: - if isinstance(bounding_boxes["confidence"], tf.RaggedTensor): - bounding_boxes["confidence"] = bounding_boxes[ - "confidence" - ].to_tensor( - default_value=default_value, - shape=_classes_shape( - info["is_batched"], - bounding_boxes["confidence"].shape, - max_boxes, - ), - ) - - return bounding_boxes diff --git a/keras_hub/src/bounding_box/to_dense_test.py b/keras_hub/src/bounding_box/to_dense_test.py deleted file mode 100644 index 91acb8137a..0000000000 --- a/keras_hub/src/bounding_box/to_dense_test.py +++ /dev/null @@ -1,23 +0,0 @@ -import pytest -import tensorflow as tf -from keras import backend - -from keras_hub.src.bounding_box import to_dense -from keras_hub.src.tests.test_case import TestCase - - -class ToDenseTest(TestCase): - @pytest.mark.skipif( - backend.backend() != "tensorflow", - reason="Only applies to backends which support raggeds", - ) - def test_converts_to_dense(self): - bounding_boxes = { - "boxes": tf.ragged.constant( - [[[0, 0, 1, 1]], [[0, 0, 1, 1], [0, 0, 1, 1], [0, 0, 1, 1]]] - ), - "classes": tf.ragged.constant([[0], [1, 2, 3]]), - } - 
bounding_boxes = to_dense.to_dense(bounding_boxes) - self.assertEqual(bounding_boxes["boxes"].shape, [2, 3, 4]) - self.assertEqual(bounding_boxes["classes"].shape, [2, 3]) diff --git a/keras_hub/src/bounding_box/to_ragged.py b/keras_hub/src/bounding_box/to_ragged.py deleted file mode 100644 index f86712dd35..0000000000 --- a/keras_hub/src/bounding_box/to_ragged.py +++ /dev/null @@ -1,86 +0,0 @@ -import keras - -import keras_hub.src.bounding_box.validate_format as validate_format -from keras_hub.src.api_export import keras_hub_export - -try: - import tensorflow as tf -except ImportError: - tf = None - - -@keras_hub_export("keras_hub.bounding_box.to_ragged") -def to_ragged(bounding_boxes, sentinel=-1, dtype="float32"): - """converts a Dense padded bounding box `tf.Tensor` to a `tf.RaggedTensor`. - - Bounding boxes are ragged tensors in most use cases. Converting them to a - dense tensor makes it easier to work with Tensorflow ecosystem. - This function can be used to filter out the masked out bounding boxes by - checking for padded sentinel value of the class_id axis of the - bounding_boxes. - - Example: - ```python - bounding_boxes = { - "boxes": tf.constant([[2, 3, 4, 5], [0, 1, 2, 3]]), - "classes": tf.constant([[-1, 1]]), - } - bounding_boxes = bounding_box.to_ragged(bounding_boxes) - print(bounding_boxes) - # { - # "boxes": [[0, 1, 2, 3]], - # "classes": [[1]] - # } - ``` - - Args: - bounding_boxes: a Tensor of bounding boxes. May be batched, or - unbatched. - sentinel: The value indicating that a bounding box does not exist at the - current index, and the corresponding box is padding, defaults to -1. - dtype: the data type to use for the underlying Tensors. - Returns: - dictionary of `tf.RaggedTensor` or 'tf.Tensor' containing the filtered - bounding boxes. - """ - if keras.config.backend() != "tensorflow": - raise NotImplementedError( - "`bounding_box.to_ragged` was called using a backend which does " - "not support ragged tensors. " - f"Current backend: {keras.backend.backend()}." 
- ) - - info = validate_format.validate_format(bounding_boxes) - - if info["ragged"]: - return bounding_boxes - - boxes = bounding_boxes.get("boxes") - classes = bounding_boxes.get("classes") - confidence = bounding_boxes.get("confidence", None) - - mask = classes != sentinel - - boxes = tf.ragged.boolean_mask(boxes, mask) - classes = tf.ragged.boolean_mask(classes, mask) - if confidence is not None: - confidence = tf.ragged.boolean_mask(confidence, mask) - - if isinstance(boxes, tf.Tensor): - boxes = tf.RaggedTensor.from_tensor(boxes) - - if isinstance(classes, tf.Tensor) and len(classes.shape) > 1: - classes = tf.RaggedTensor.from_tensor(classes) - - if confidence is not None: - if isinstance(confidence, tf.Tensor) and len(confidence.shape) > 1: - confidence = tf.RaggedTensor.from_tensor(confidence) - - result = bounding_boxes.copy() - result["boxes"] = tf.cast(boxes, dtype) - result["classes"] = tf.cast(classes, dtype) - - if confidence is not None: - result["confidence"] = tf.cast(confidence, dtype) - - return result diff --git a/keras_hub/src/bounding_box/to_ragged_test.py b/keras_hub/src/bounding_box/to_ragged_test.py deleted file mode 100644 index 9b76866ddc..0000000000 --- a/keras_hub/src/bounding_box/to_ragged_test.py +++ /dev/null @@ -1,87 +0,0 @@ -import numpy as np -import pytest -from keras import backend - -from keras_hub.src.bounding_box import to_dense -from keras_hub.src.bounding_box import to_ragged -from keras_hub.src.tests.test_case import TestCase - - -class ToRaggedTest(TestCase): - @pytest.mark.skipif( - backend.backend() != "tensorflow", - reason="Only applies to backends which support raggeds", - ) - def test_converts_to_ragged(self): - bounding_boxes = { - "boxes": np.array( - [[[0, 0, 0, 0], [0, 0, 0, 0]], [[2, 3, 4, 5], [0, 1, 2, 3]]] - ), - "classes": np.array([[-1, -1], [-1, 1]]), - "confidence": np.array([[0.5, 0.7], [0.23, 0.12]]), - } - bounding_boxes = to_ragged.to_ragged(bounding_boxes) - - self.assertEqual(bounding_boxes["boxes"][1].shape, [1, 4]) - self.assertEqual(bounding_boxes["classes"][1].shape, [1]) - self.assertEqual( - bounding_boxes["confidence"][1].shape, - [ - 1, - ], - ) - - self.assertEqual(bounding_boxes["classes"][0].shape, [0]) - self.assertEqual(bounding_boxes["boxes"][0].shape, [0, 4]) - self.assertEqual( - bounding_boxes["confidence"][0].shape, - [ - 0, - ], - ) - - @pytest.mark.skipif( - backend.backend() != "tensorflow", - reason="Only applies to backends which support raggeds", - ) - def test_round_trip(self): - original = { - "boxes": np.array( - [ - [[0, 0, 0, 0], [-1, -1, -1, -1]], - [[-1, -1, -1, -1], [-1, -1, -1, -1]], - ] - ), - "classes": np.array([[1, -1], [-1, -1]]), - "confidence": np.array([[0.5, -1], [-1, -1]]), - } - bounding_boxes = to_ragged.to_ragged(original) - bounding_boxes = to_dense.to_dense(bounding_boxes, max_boxes=2) - - self.assertEqual(bounding_boxes["boxes"][1].shape, [2, 4]) - self.assertEqual(bounding_boxes["classes"][1].shape, [2]) - self.assertEqual(bounding_boxes["classes"][0].shape, [2]) - self.assertEqual(bounding_boxes["boxes"][0].shape, [2, 4]) - self.assertEqual(bounding_boxes["confidence"][0].shape, [2]) - - self.assertAllEqual(bounding_boxes["boxes"], original["boxes"]) - self.assertAllEqual(bounding_boxes["classes"], original["classes"]) - self.assertAllEqual( - bounding_boxes["confidence"], original["confidence"] - ) - - @pytest.mark.skipif( - backend.backend() == "tensorflow", - reason="Only applies to backends which don't support raggeds", - ) - def 
test_backend_without_raggeds_throws(self): - bounding_boxes = { - "boxes": np.array( - [[[0, 0, 0, 0], [0, 0, 0, 0]], [[2, 3, 4, 5], [0, 1, 2, 3]]] - ), - "classes": np.array([[-1, -1], [-1, 1]]), - "confidence": np.array([[0.5, 0.7], [0.23, 0.12]]), - } - - with self.assertRaisesRegex(NotImplementedError, "support ragged"): - to_ragged.to_ragged(bounding_boxes) diff --git a/keras_hub/src/bounding_box/utils.py b/keras_hub/src/bounding_box/utils.py deleted file mode 100644 index ac4fe8d05b..0000000000 --- a/keras_hub/src/bounding_box/utils.py +++ /dev/null @@ -1,181 +0,0 @@ -"""Utility functions for working with bounding boxes.""" - -from keras import ops - -from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.bounding_box import converters -from keras_hub.src.bounding_box.formats import XYWH - - -@keras_hub_export("keras_hub.bounding_box.is_relative") -def is_relative(bounding_box_format): - """A util to check if a bounding box format uses relative coordinates""" - if bounding_box_format.lower() not in converters.TO_XYXY_CONVERTERS: - raise ValueError( - "`is_relative()` received an unsupported format for the argument " - f"`bounding_box_format`. `bounding_box_format` should be one of " - f"{converters.TO_XYXY_CONVERTERS.keys()}. " - f"Got bounding_box_format={bounding_box_format}" - ) - - return bounding_box_format.startswith("rel") - - -@keras_hub_export("keras_hub.bounding_box.as_relative") -def as_relative(bounding_box_format): - """A util to get the relative equivalent of a provided bounding box format. - - If the specified format is already a relative format, - it will be returned unchanged. - """ - - if not is_relative(bounding_box_format): - return "rel_" + bounding_box_format - - return bounding_box_format - - -def _relative_area(boxes, bounding_box_format): - boxes = converters.convert_format( - boxes, - source=bounding_box_format, - target="rel_xywh", - ) - widths = boxes[..., XYWH.WIDTH] - heights = boxes[..., XYWH.HEIGHT] - # handle corner case where shear performs a full inversion. - return ops.where( - ops.logical_and(widths > 0, heights > 0), widths * heights, 0.0 - ) - - -@keras_hub_export("keras_hub.bounding_box.clip_to_image") -def clip_to_image( - bounding_boxes, bounding_box_format, images=None, image_shape=None -): - """clips bounding boxes to image boundaries. - - `clip_to_image()` clips bounding boxes that have coordinates out of bounds - of an image down to the boundaries of the image. This is done by converting - the bounding box to relative formats, then clipping them to the `[0, 1]` - range. Additionally, bounding boxes that end up with a zero area have their - class ID set to -1, indicating that there is no object present in them. - - Args: - bounding_boxes: bounding box tensor to clip. - bounding_box_format: the KerasCV bounding box format the bounding boxes - are in. - images: list of images to clip the bounding boxes to. - image_shape: the shape of the images to clip the bounding boxes to. 
- """ - boxes, classes = bounding_boxes["boxes"], bounding_boxes["classes"] - - boxes = converters.convert_format( - boxes, - source=bounding_box_format, - target="rel_xyxy", - images=images, - image_shape=image_shape, - ) - boxes, classes, images, squeeze = _format_inputs(boxes, classes, images) - x1, y1, x2, y2 = ops.split(boxes, 4, axis=-1) - clipped_bounding_boxes = ops.concatenate( - [ - ops.clip(x1, 0, 1), - ops.clip(y1, 0, 1), - ops.clip(x2, 0, 1), - ops.clip(y2, 0, 1), - ], - axis=-1, - ) - areas = _relative_area( - clipped_bounding_boxes, bounding_box_format="rel_xyxy" - ) - clipped_bounding_boxes = converters.convert_format( - clipped_bounding_boxes, - source="rel_xyxy", - target=bounding_box_format, - images=images, - image_shape=image_shape, - ) - clipped_bounding_boxes = ops.where( - ops.expand_dims(areas > 0.0, axis=-1), clipped_bounding_boxes, -1.0 - ) - classes = ops.where(areas > 0.0, classes, -1) - nan_indices = ops.any(ops.isnan(clipped_bounding_boxes), axis=-1) - classes = ops.where(nan_indices, -1, classes) - - # TODO update dict and return - clipped_bounding_boxes, classes = _format_outputs( - clipped_bounding_boxes, classes, squeeze - ) - - bounding_boxes.update({"boxes": clipped_bounding_boxes, "classes": classes}) - - return bounding_boxes - - -@keras_hub_export("keras_hub.bounding_box.clip_boxes") -def clip_boxes(boxes, image_shape): - """Clip boxes to the boundaries of the image shape""" - if boxes.shape[-1] != 4: - raise ValueError( - "boxes.shape[-1] is {:d}, but must be 4.".format(boxes.shape[-1]) - ) - - if isinstance(image_shape, list) or isinstance(image_shape, tuple): - height, width, _ = image_shape - max_length = ops.stack([height, width, height, width], axis=-1) - else: - image_shape = ops.cast(image_shape, dtype=boxes.dtype) - height = image_shape[0] - width = image_shape[1] - max_length = ops.stack([height, width, height, width], axis=-1) - - clipped_boxes = ops.maximum(ops.minimum(boxes, max_length), 0.0) - return clipped_boxes - - -def _format_inputs(boxes, classes, images): - boxes_rank = len(boxes.shape) - if boxes_rank > 3: - raise ValueError( - "Expected len(boxes.shape)=2, or len(boxes.shape)=3, got " - f"len(boxes.shape)={boxes_rank}" - ) - boxes_includes_batch = boxes_rank == 3 - # Determine if images needs an expand_dims() call - if images is not None: - images_rank = len(images.shape) - if images_rank > 4: - raise ValueError( - "Expected len(images.shape)=2, or len(images.shape)=3, got " - f"len(images.shape)={images_rank}" - ) - images_include_batch = images_rank == 4 - if boxes_includes_batch != images_include_batch: - raise ValueError( - "clip_to_image() expects both boxes and images to be batched, " - "or both boxes and images to be unbatched. Received " - f"len(boxes.shape)={boxes_rank}, " - f"len(images.shape)={images_rank}. Expected either " - "len(boxes.shape)=2 AND len(images.shape)=3, or " - "len(boxes.shape)=3 AND len(images.shape)=4." 
- ) - if not images_include_batch: - images = ops.expand_dims(images, axis=0) - - if not boxes_includes_batch: - return ( - ops.expand_dims(boxes, axis=0), - ops.expand_dims(classes, axis=0), - images, - True, - ) - return boxes, classes, images, False - - -def _format_outputs(boxes, classes, squeeze): - if squeeze: - return ops.squeeze(boxes, axis=0), ops.squeeze(classes, axis=0) - return boxes, classes diff --git a/keras_hub/src/bounding_box/utils_test.py b/keras_hub/src/bounding_box/utils_test.py deleted file mode 100644 index 40ad8e6e07..0000000000 --- a/keras_hub/src/bounding_box/utils_test.py +++ /dev/null @@ -1,155 +0,0 @@ -import numpy as np -from keras import ops - -from keras_hub.src.bounding_box import utils -from keras_hub.src.tests.test_case import TestCase - - -class BoundingBoxUtilTest(TestCase): - def test_clip_to_image_standard(self): - # Test xyxy format unbatched - height = 256 - width = 256 - bounding_boxes = { - "boxes": np.array([[200, 200, 400, 400], [100, 100, 300, 300]]), - "classes": np.array([0, 0]), - } - image = ops.ones(shape=(height, width, 3)) - bounding_boxes = utils.clip_to_image( - bounding_boxes, bounding_box_format="xyxy", images=image - ) - boxes = bounding_boxes["boxes"] - self.assertAllGreaterEqual(ops.convert_to_numpy(boxes), 0) - ( - x1, - y1, - x2, - y2, - ) = ops.split(boxes, 4, axis=1) - self.assertAllLessEqual( - ops.convert_to_numpy(ops.concatenate([x1, x2], axis=1)), width - ) - self.assertAllLessEqual( - ops.convert_to_numpy(ops.concatenate([y1, y2], axis=1)), height - ) - # Test relative format batched - image = ops.ones(shape=(1, height, width, 3)) - - bounding_boxes = { - "boxes": np.array([[[0.2, -1, 1.2, 0.3], [0.4, 1.5, 0.2, 0.3]]]), - "classes": np.array([[0, 0]]), - } - bounding_boxes = utils.clip_to_image( - bounding_boxes, bounding_box_format="rel_xyxy", images=image - ) - boxes = bounding_boxes["boxes"] - self.assertAllLessEqual(ops.convert_to_numpy(boxes), 1) - - def test_clip_to_image_filters_fully_out_bounding_boxes(self): - # Test xyxy format unbatched - height = 256 - width = 256 - bounding_boxes = { - "boxes": np.array([[257, 257, 400, 400], [100, 100, 300, 300]]), - "classes": np.array([0, 0]), - } - image = ops.ones(shape=(height, width, 3)) - bounding_boxes = utils.clip_to_image( - bounding_boxes, bounding_box_format="xyxy", images=image - ) - - ( - self.assertAllEqual( - bounding_boxes["boxes"], - np.array([[-1, -1, -1, -1], [100, 100, 256, 256]]), - ), - ) - self.assertAllEqual( - bounding_boxes["classes"], - np.array([-1, 0]), - ) - - def test_clip_to_image_filters_fully_out_bounding_boxes_negative_area(self): - # Test xyxy format unbatched - height = 256 - width = 256 - bounding_boxes = { - "boxes": np.array([[110, 120, 100, 100], [100, 100, 300, 300]]), - "classes": np.array([0, 0]), - } - image = ops.ones(shape=(height, width, 3)) - bounding_boxes = utils.clip_to_image( - bounding_boxes, bounding_box_format="xyxy", images=image - ) - self.assertAllEqual( - bounding_boxes["boxes"], - np.array( - [ - [ - -1, - -1, - -1, - -1, - ], - [ - 100, - 100, - 256, - 256, - ], - ] - ), - ) - self.assertAllEqual( - bounding_boxes["classes"], - np.array([-1, 0]), - ) - - def test_clip_to_image_filters_nans(self): - # Test xyxy format unbatched - height = 256 - width = 256 - bounding_boxes = { - "boxes": np.array( - [[0, float("NaN"), 100, 100], [100, 100, 300, 300]] - ), - "classes": np.array([0, 0]), - } - image = ops.ones(shape=(height, width, 3)) - bounding_boxes = utils.clip_to_image( - bounding_boxes, 
bounding_box_format="xyxy", images=image - ) - self.assertAllEqual( - bounding_boxes["boxes"], - np.array( - [ - [ - -1, - -1, - -1, - -1, - ], - [ - 100, - 100, - 256, - 256, - ], - ] - ), - ) - self.assertAllEqual( - bounding_boxes["classes"], - np.array([-1, 0]), - ) - - def test_is_relative_util(self): - self.assertTrue(utils.is_relative("rel_xyxy")) - self.assertFalse(utils.is_relative("xyxy")) - - with self.assertRaises(ValueError): - _ = utils.is_relative("bad_format") - - def test_as_relative_util(self): - self.assertEqual(utils.as_relative("yxyx"), "rel_yxyx") - self.assertEqual(utils.as_relative("rel_xywh"), "rel_xywh") diff --git a/keras_hub/src/bounding_box/validate_format.py b/keras_hub/src/bounding_box/validate_format.py deleted file mode 100644 index 8680dbb693..0000000000 --- a/keras_hub/src/bounding_box/validate_format.py +++ /dev/null @@ -1,85 +0,0 @@ -from keras_hub.src.api_export import keras_hub_export - -try: - import tensorflow as tf -except ImportError: - tf = None - - -@keras_hub_export("keras_hub.bounding_box.validate_format") -def validate_format(bounding_boxes, variable_name="bounding_boxes"): - """validates that a given set of bounding boxes complies with KerasHub - format. - - For a set of bounding boxes to be valid it must satisfy the following - conditions: - - `bounding_boxes` must be a dictionary - - contains keys `"boxes"` and `"classes"` - - each entry must have matching first two dimensions; representing the batch - axis and the number of boxes per image axis. - - either both `"boxes"` and `"classes"` are batched, or both are unbatched. - - Additionally, one of the following must be satisfied: - - `"boxes"` and `"classes"` are both Ragged - - `"boxes"` and `"classes"` are both Dense - - `"boxes"` and `"classes"` are unbatched - - Args: - bounding_boxes: dictionary of bounding boxes according to KerasCV - format. - - Raises: - ValueError if any of the above conditions are not met - """ - if not isinstance(bounding_boxes, dict): - raise ValueError( - f"Expected `{variable_name}` to be a dictionary, got " - f"`{variable_name}={bounding_boxes}`." - ) - if not all([x in bounding_boxes for x in ["boxes", "classes"]]): - raise ValueError( - f"Expected `{variable_name}` to be a dictionary containing keys " - "`'classes'` and `'boxes'`. Got " - f"`{variable_name}.keys()={bounding_boxes.keys()}`." - ) - - boxes = bounding_boxes.get("boxes") - classes = bounding_boxes.get("classes") - info = {} - - is_batched = len(boxes.shape) == 3 - info["is_batched"] = is_batched - info["ragged"] = isinstance(boxes, tf.RaggedTensor) - - if not is_batched: - if boxes.shape[:1] != classes.shape[:1]: - raise ValueError( - "Expected `boxes` and `classes` to have matching dimensions " - "on the first axis when operating in unbatched mode. Got " - f"`boxes.shape={boxes.shape}`, `classes.shape={classes.shape}`." - ) - - info["classes_one_hot"] = len(classes.shape) == 2 - # No Ragged checks needed in unbatched mode. - return info - - info["classes_one_hot"] = len(classes.shape) == 3 - - if isinstance(boxes, tf.RaggedTensor) != isinstance( - classes, tf.RaggedTensor - ): - raise ValueError( - "Either both `boxes` and `classes` " - "should be Ragged, or neither should be ragged." - f" Got `type(boxes)={type(boxes)}`, type(classes)={type(classes)}." - ) - - # Batched mode checks - if boxes.shape[:2] != classes.shape[:2]: - raise ValueError( - "Expected `boxes` and `classes` to have matching dimensions " - "on the first two axes when operating in batched mode. 
" - f"Got `boxes.shape={boxes.shape}`, `classes.shape={classes.shape}`." - ) - - return info diff --git a/keras_hub/src/bounding_box/validate_format_test.py b/keras_hub/src/bounding_box/validate_format_test.py deleted file mode 100644 index e2025e290a..0000000000 --- a/keras_hub/src/bounding_box/validate_format_test.py +++ /dev/null @@ -1,34 +0,0 @@ -import numpy as np - -from keras_hub.src.bounding_box import validate_format -from keras_hub.src.tests.test_case import TestCase - - -class ValidateTest(TestCase): - def test_raises_nondict(self): - with self.assertRaisesRegex( - ValueError, "Expected `bounding_boxes` to be a dictionary, got " - ): - validate_format.validate_format(np.ones((4, 3, 6))) - - def test_mismatch_dimensions(self): - with self.assertRaisesRegex( - ValueError, - "Expected `boxes` and `classes` to have matching dimensions", - ): - validate_format.validate_format( - {"boxes": np.ones((4, 3, 6)), "classes": np.ones((4, 6))} - ) - - def test_bad_keys(self): - with self.assertRaisesRegex(ValueError, "containing keys"): - validate_format.validate_format( - { - "box": [ - 1, - 2, - 3, - ], - "class": [1234], - } - ) diff --git a/keras_hub/src/layers/modeling/non_max_supression.py b/keras_hub/src/layers/modeling/non_max_supression.py index 70595492e5..207891ac9e 100644 --- a/keras_hub/src/layers/modeling/non_max_supression.py +++ b/keras_hub/src/layers/modeling/non_max_supression.py @@ -540,7 +540,7 @@ def mask_invalid_detections(bounding_boxes): ) boxes = bounding_boxes.get("boxes") - classes = bounding_boxes.get("labels") + labels = bounding_boxes.get("labels") confidence = bounding_boxes.get("confidence", None) num_detections = bounding_boxes.get("num_detections") @@ -551,7 +551,7 @@ def mask_invalid_detections(bounding_boxes): ) mask = mask < num_detections[:, None] - classes = ops.where(mask, classes, -ops.ones_like(classes)) + labels = ops.where(mask, labels, -ops.ones_like(labels)) if confidence is not None: confidence = ops.where(mask, confidence, -ops.ones_like(confidence)) @@ -564,7 +564,7 @@ def mask_invalid_detections(bounding_boxes): result = bounding_boxes.copy() result["boxes"] = boxes - result["labels"] = classes + result["labels"] = labels if confidence is not None: result["confidence"] = confidence diff --git a/keras_hub/src/models/image_object_detector_preprocessor.py b/keras_hub/src/models/image_object_detector_preprocessor.py index 2f89d216ef..6ff6e2c21d 100644 --- a/keras_hub/src/models/image_object_detector_preprocessor.py +++ b/keras_hub/src/models/image_object_detector_preprocessor.py @@ -53,5 +53,5 @@ def __init__( @preprocessing_function def call(self, x, y=None, sample_weight=None): if self.image_converter: - x = self.image_converter(x) + x, y = self.image_converter(x, y) return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_hub/src/models/retinanet/retinanet_image_converter.py b/keras_hub/src/models/retinanet/retinanet_image_converter.py index 6d26323a0a..279c5ef92d 100644 --- a/keras_hub/src/models/retinanet/retinanet_image_converter.py +++ b/keras_hub/src/models/retinanet/retinanet_image_converter.py @@ -1,3 +1,5 @@ +import keras + from keras_hub.src.api_export import keras_hub_export from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone @@ -10,25 +12,38 @@ class RetinaNetImageConverter(ImageConverter): def __init__( self, - image_size=None, - scale=None, - offset=None, + bounding_box_format, + 
pad_to_aspect_ratio=False, norm_mean=[0.485, 0.456, 0.406], norm_std=[0.229, 0.224, 0.225], **kwargs, ): super().__init__(**kwargs) - self.image_size = image_size - self.scale = scale - self.offset = offset + self.resizing = keras.layers.Resizing( + height=self.image_size[0] if self.image_size else None, + width=self.image_size[1] if self.image_size else None, + bounding_box_format=bounding_box_format, + crop_to_aspect_ratio=self.crop_to_aspect_ratio, + pad_to_aspect_ratio=pad_to_aspect_ratio, + interpolation=self.interpolation, + data_format=self.data_format, + dtype=self.dtype_policy, + name="resizing", + ) + + self.bounding_box_format = bounding_box_format + self.pad_to_aspect_ratio = pad_to_aspect_ratio self.norm_mean = norm_mean self.norm_std = norm_std - self.built = True @preprocessing_function - def call(self, inputs): - # TODO: https://github.com/keras-team/keras-hub/issues/1965 - x = inputs + def call(self, x, y=None, sample_weight=None): + if y is not None: + inputs = self.resizing({"images": x, "bounding_boxes": y}) + x = inputs["images"] + y = inputs["bounding_boxes"] + else: + x = self.resizing(x) # Rescaling Image if self.scale is not None: x = x * self._expand_non_channel_dims(self.scale, x) @@ -40,12 +55,14 @@ def call(self, inputs): if self.norm_std: x = x / self._expand_non_channel_dims(self.norm_std, x) - return x + return x, y def get_config(self): config = super().get_config() config.update( { + "bounding_box_format": self.bounding_box_format, + "pad_to_aspect_ratio": self.pad_to_aspect_ratio, "norm_mean": self.norm_mean, "norm_std": self.norm_std, } diff --git a/keras_hub/src/models/retinanet/retinanet_label_encoder.py b/keras_hub/src/models/retinanet/retinanet_label_encoder.py index bfb74c83da..886d86422d 100644 --- a/keras_hub/src/models/retinanet/retinanet_label_encoder.py +++ b/keras_hub/src/models/retinanet/retinanet_label_encoder.py @@ -4,9 +4,6 @@ from keras import ops # TODO: https://github.com/keras-team/keras-hub/issues/1965 -from keras_hub.src.bounding_box.converters import convert_format -from keras_hub.src.bounding_box.converters import encode_box_to_deltas -from keras_hub.src.bounding_box.iou import compute_iou from keras_hub.src.layers.modeling.box_matcher import BoxMatcher from keras_hub.src.utils import tensor_utils @@ -113,7 +110,7 @@ def call(self, images, gt_boxes, gt_classes): "support unbatched inputs for the `images` argument. " f"Received `shape(images)={images_shape}`." ) - image_shape = images_shape[1:] + height, width, _ = images_shape[1:] if len(ops.shape(gt_classes)) == 2: gt_classes = ops.expand_dims(gt_classes, axis=-1) @@ -122,14 +119,14 @@ def call(self, images, gt_boxes, gt_classes): anchor_boxes = ops.concatenate(list(anchor_boxes.values()), axis=0) box_targets, class_targets = self._encode_sample( - gt_boxes, gt_classes, anchor_boxes, image_shape + gt_boxes, gt_classes, anchor_boxes, height, width ) box_targets = ops.reshape( box_targets, (-1, ops.shape(box_targets)[1], 4) ) return box_targets, class_targets - def _encode_sample(self, gt_boxes, gt_classes, anchor_boxes, image_shape): + def _encode_sample(self, gt_boxes, gt_classes, anchor_boxes, height, width): """Creates box and classification targets for a batched sample. Matches ground truth boxes to anchor boxes based on IOU. 
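
Note: the `RetinaNetImageConverter.call` rewrite above resizes images together with their bounding boxes and then applies a scale/offset rescale followed by per-channel `norm_mean`/`norm_std` standardization. A minimal NumPy sketch of that normalization step, assuming a `channels_last` uint8 image and the ImageNet statistics used as defaults above (the `normalize_image` helper name is illustrative, not part of the patch):

    import numpy as np

    def normalize_image(
        x,
        scale=1 / 255.0,
        offset=0.0,
        norm_mean=(0.485, 0.456, 0.406),
        norm_std=(0.229, 0.224, 0.225),
    ):
        # Rescale raw pixel values, e.g. uint8 [0, 255] -> float [0, 1].
        x = x.astype("float32") * scale + offset
        # Standardize each channel with the ImageNet mean/std defaults.
        return (x - np.asarray(norm_mean)) / np.asarray(norm_std)

    image = np.random.randint(0, 256, size=(800, 800, 3), dtype="uint8")
    print(normalize_image(image).shape)  # (800, 800, 3)
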
@@ -149,23 +146,25 @@ def _encode_sample(self, gt_boxes, gt_classes, anchor_boxes, image_shape): anchor_boxes: A Tensor with the shape `[total_anchors, 4]` representing all the anchor boxes for a given input image shape, where each anchor box is of the format `[x, y, width, height]`. - image_shape: Tuple indicating the image shape `[H, W, C]`. + height: int. + width: int. Returns: Encoded bounding boxes in the format of `center_yxwh` and corresponding labels for each encoded bounding box. """ - anchor_boxes = convert_format( + anchor_boxes = keras.utils.bounding_boxes.convert_format( anchor_boxes, source=self.anchor_generator.bounding_box_format, target=self.bounding_box_format, - image_shape=image_shape, + height=height, + width=width, ) - iou_matrix = compute_iou( + iou_matrix = keras.utils.bounding_boxes.compute_iou( anchor_boxes, gt_boxes, bounding_box_format=self.bounding_box_format, - image_shape=image_shape, + image_shape=(height, width, 3), ) matched_gt_idx, matched_vals = self.box_matcher(iou_matrix) @@ -179,14 +178,14 @@ def _encode_sample(self, gt_boxes, gt_classes, anchor_boxes, image_shape): matched_gt_boxes, (-1, ops.shape(matched_gt_boxes)[1], 4) ) - box_targets = encode_box_to_deltas( + box_targets = keras.utils.bounding_boxes.encode_box_to_deltas( anchors=anchor_boxes, boxes=matched_gt_boxes, anchor_format=self.bounding_box_format, box_format=self.bounding_box_format, encoding_format=self.encoding_format, variance=self.box_variance, - image_shape=image_shape, + image_shape=(height, width, 3), ) matched_gt_cls_ids = tensor_utils.target_gather( diff --git a/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py b/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py index d05bf5a99a..ca4f151309 100644 --- a/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py +++ b/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py @@ -1,7 +1,7 @@ import numpy as np from keras import ops -from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator +from keras_hub.src.layers.modeling.anchor_generator import AnchorGenerator from keras_hub.src.models.retinanet.retinanet_label_encoder import ( RetinaNetLabelEncoder, ) diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index d8b9e17304..d9523a3a58 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -2,10 +2,6 @@ from keras import ops from keras_hub.src.api_export import keras_hub_export - -# TODO: https://github.com/keras-team/keras-hub/issues/1965 -from keras_hub.src.bounding_box.converters import convert_format -from keras_hub.src.bounding_box.converters import decode_deltas_to_boxes from keras_hub.src.layers.modeling.anchor_generator import AnchorGenerator from keras_hub.src.layers.modeling.non_max_supression import NonMaxSuppression from keras_hub.src.models.image_object_detector import ImageObjectDetector @@ -204,17 +200,19 @@ def __init__( ) def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): - y_for_label_encoder = convert_format( + _, height, width, _ = keras.ops.shape(x) + y_for_label_encoder = keras.utils.bounding_boxes.convert_format( y, source=self.bounding_box_format, target=self.label_encoder.bounding_box_format, - images=x, + height=height, + width=width, ) - boxes, classes = self.label_encoder( + boxes, labels = self.label_encoder( images=x, gt_boxes=y_for_label_encoder["boxes"], - 
gt_classes=y_for_label_encoder["classes"], + gt_classes=y_for_label_encoder["labels"], ) box_pred = y_pred["bbox_regression"] @@ -242,11 +240,11 @@ def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): ) cls_labels = ops.one_hot( - ops.cast(classes, "int32"), self.num_classes, dtype="float32" + ops.cast(labels, "int32"), self.num_classes, dtype="float32" ) - positive_mask = ops.cast(ops.greater(classes, -1.0), dtype="float32") + positive_mask = ops.cast(ops.greater(labels, -1.0), dtype="float32") normalizer = ops.sum(positive_mask) - cls_weights = ops.cast(ops.not_equal(classes, -2.0), dtype="float32") + cls_weights = ops.cast(ops.not_equal(labels, -2.0), dtype="float32") cls_weights /= normalizer box_weights = positive_mask / normalizer @@ -306,32 +304,32 @@ def decode_predictions(self, predictions, data): images, _ = data else: images = data - image_shape = ops.shape(images)[1:] + height, width, channels = ops.shape(images)[1:] anchor_boxes = self.anchor_generator(images) anchor_boxes = ops.concatenate(list(anchor_boxes.values()), axis=0) - box_pred = decode_deltas_to_boxes( + box_pred = keras.utils.bounding_boxes.decode_deltas_to_boxes( anchors=anchor_boxes, boxes_delta=box_pred, encoded_format="center_xywh", anchor_format=self.anchor_generator.bounding_box_format, box_format=self.bounding_box_format, - image_shape=image_shape, + image_shape=(height, width, channels), ) # box_pred is now in "self.bounding_box_format" format - box_pred = convert_format( + box_pred = keras.utils.bounding_boxes.convert_format( box_pred, source=self.bounding_box_format, target=self.prediction_decoder.bounding_box_format, - image_shape=image_shape, - ) - y_pred = self.prediction_decoder( - box_pred, cls_pred, image_shape=image_shape + height=height, + width=width, ) - y_pred["boxes"] = convert_format( + y_pred = self.prediction_decoder(box_pred, cls_pred, images=images) + y_pred["boxes"] = keras.utils.bounding_boxes.convert_format( y_pred["boxes"], source=self.prediction_decoder.bounding_box_format, target=self.bounding_box_format, - image_shape=image_shape, + height=height, + width=width, ) return y_pred diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py index 53d7461bb1..38e917c4a7 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py @@ -1,8 +1,8 @@ import numpy as np import pytest +from keras_hub.src.layers.modeling.anchor_generator import AnchorGenerator from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone -from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone from keras_hub.src.models.retinanet.retinanet_image_converter import ( RetinaNetImageConverter, @@ -53,7 +53,9 @@ def setUp(self): bounding_box_format="yxyx", anchor_generator=anchor_generator ) - image_converter = RetinaNetImageConverter(scale=1 / 255.0) + image_converter = RetinaNetImageConverter( + bounding_box_format="yxyx", scale=1 / 255.0, image_size=(800, 800) + ) preprocessor = RetinaNetObjectDetectorPreprocessor( image_converter=image_converter @@ -76,7 +78,7 @@ def setUp(self): "boxes": np.array( [[[20.0, 10.0, 12.0, 11.0], [30.0, 20.0, 40.0, 12.0]]] ), - "classes": np.array([[0, 2]]), + "labels": np.array([[0, 2]]), } self.train_data = (self.images, self.labels) @@ -87,7 +89,7 @@ def test_detection_basics(self): 
train_data=self.train_data, expected_output_shape={ "boxes": (1, 100, 4), - "classes": (1, 100), + "labels": (1, 100), "confidence": (1, 100), "num_detections": (1,), }, From d4c006bf762109eebc118b63123fc6f81c3ff2f8 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Wed, 8 Jan 2025 13:51:27 -0800 Subject: [PATCH 33/33] remove api import --- keras_hub/api/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras_hub/api/__init__.py b/keras_hub/api/__init__.py index 7e44f01381..fa8636ab70 100644 --- a/keras_hub/api/__init__.py +++ b/keras_hub/api/__init__.py @@ -4,7 +4,6 @@ since your modifications would be overwritten. """ -from keras_hub.api import bounding_box from keras_hub.api import layers from keras_hub.api import metrics from keras_hub.api import models
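
Note: with this final hunk the `keras_hub.api.bounding_box` export is gone, completing the migration from the deleted `keras_hub.src.bounding_box` utilities to the `keras.utils.bounding_boxes` helpers called throughout the RetinaNet hunks above. A hedged sketch of the replacement call pattern, assuming a Keras version that ships `keras.utils.bounding_boxes` (box values are made up for illustration):

    import numpy as np
    import keras

    # One image worth of boxes in absolute "xyxy" pixel coordinates.
    boxes = np.array([[10.0, 20.0, 110.0, 220.0]])

    # Call sites now pass explicit height/width keywords instead of the old
    # `images=`/`image_shape=` arguments used by keras_hub.src.bounding_box.
    rel_boxes = keras.utils.bounding_boxes.convert_format(
        boxes, source="xyxy", target="rel_xyxy", height=800, width=800
    )
    print(rel_boxes)  # coordinates rescaled to the [0, 1] range
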