From afd73c7c90b07636ae04808bdb3c94bad30a1111 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Wed, 14 Aug 2024 23:24:18 +0000 Subject: [PATCH 1/6] Add DenseNet --- keras_nlp/api/models/__init__.py | 4 + keras_nlp/src/models/densenet/__init__.py | 13 ++ .../src/models/densenet/densenet_backbone.py | 206 ++++++++++++++++++ .../models/densenet/densenet_backbone_test.py | 48 ++++ .../densenet/densenet_image_classifier.py | 124 +++++++++++ .../densenet_image_classifier_test.py | 63 ++++++ 6 files changed, 458 insertions(+) create mode 100644 keras_nlp/src/models/densenet/__init__.py create mode 100644 keras_nlp/src/models/densenet/densenet_backbone.py create mode 100644 keras_nlp/src/models/densenet/densenet_backbone_test.py create mode 100644 keras_nlp/src/models/densenet/densenet_image_classifier.py create mode 100644 keras_nlp/src/models/densenet/densenet_image_classifier_test.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 783cfd5087..0f76a39577 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -68,6 +68,10 @@ from keras_nlp.src.models.deberta_v3.deberta_v3_tokenizer import ( DebertaV3Tokenizer, ) +from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone +from keras_nlp.src.models.densenet.densenet_image_classifier import ( + DensekNetImageClassifier, +) from keras_nlp.src.models.distil_bert.distil_bert_backbone import ( DistilBertBackbone, ) diff --git a/keras_nlp/src/models/densenet/__init__.py b/keras_nlp/src/models/densenet/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/densenet/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/models/densenet/densenet_backbone.py b/keras_nlp/src/models/densenet/densenet_backbone.py new file mode 100644 index 0000000000..bd2357d1f0 --- /dev/null +++ b/keras_nlp/src/models/densenet/densenet_backbone.py @@ -0,0 +1,206 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone + +BN_AXIS = 3 +BN_EPSILON = 1.001e-5 + + +@keras_nlp_export("keras_nlp.models.DenseNetBackbone") +class DenseNetBackbone(Backbone): + """Instantiates the DenseNet architecture. + + Args: + stackwise_num_repeats: list of ints, number of repeated convolutional + blocks per dense block. + include_rescaling: bool, whether to rescale the inputs. If set + to `True`, inputs will be passed through a `Rescaling(1/255.0)` + layer. + input_image_shape: optional shape tuple, defaults to (224, 224, 3). + compression_ratio: float, compression rate at transition layers, + defaults to 0.5. 
+ growth_rate: int, number of filters added by each convolutional + block within a dense block. Defaults to 32. + + Examples: + ```python + input_data = np.ones(shape=(8, 224, 224, 3)) + + # Pretrained backbone + model = keras_nlp.models.DenseNetBackbone.from_preset("densenet121_imagenet") + model(input_data) + + # Randomly initialized backbone with a custom config + model = keras_nlp.models.DenseNetBackbone( + stackwise_num_repeats=[6, 12, 24, 16], + include_rescaling=False, + ) + model(input_data) + ``` + """ + + def __init__( + self, + *, + stackwise_num_repeats, + include_rescaling, + input_image_shape=(224, 224, 3), + compression_ratio=0.5, + growth_rate=32, + **kwargs, + ): + # === Functional Model === + image_input = keras.layers.Input(shape=input_image_shape) + + x = image_input + if include_rescaling: + x = keras.layers.Rescaling(1 / 255.0)(x) + + x = keras.layers.Conv2D( + 64, 7, strides=2, use_bias=False, padding="same", name="conv1_conv" + )(x) + x = keras.layers.BatchNormalization( + axis=BN_AXIS, epsilon=BN_EPSILON, name="conv1_bn" + )(x) + x = keras.layers.Activation("relu", name="conv1_relu")(x) + x = keras.layers.MaxPooling2D( + 3, strides=2, padding="same", name="pool1" + )(x) + + for stack_index in range(len(stackwise_num_repeats) - 1): + index = stack_index + 2 + x = apply_dense_block( + x, + stackwise_num_repeats[stack_index], + growth_rate, + name=f"conv{index}", + ) + x = apply_transition_block( + x, compression_ratio, name=f"pool{index}" + ) + + x = apply_dense_block( + x, + stackwise_num_repeats[-1], + growth_rate, + name=f"conv{len(stackwise_num_repeats) + 1}", + ) + + x = keras.layers.BatchNormalization( + axis=BN_AXIS, epsilon=BN_EPSILON, name="bn" + )(x) + x = keras.layers.Activation("relu", name="relu")(x) + + super().__init__(inputs=image_input, outputs=x, **kwargs) + + # === Config === + self.stackwise_num_repeats = stackwise_num_repeats + self.include_rescaling = include_rescaling + self.compression_ratio = compression_ratio + self.growth_rate = growth_rate + 
self.input_image_shape = input_image_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "stackwise_num_repeats": self.stackwise_num_repeats, + "include_rescaling": self.include_rescaling, + "compression_ratio": self.compression_ratio, + "growth_rate": self.growth_rate, + "input_image_shape": self.input_image_shape, + } + ) + return config + + +def apply_dense_block(x, num_repeats, growth_rate, name=None): + """A dense block. + + Args: + x: input tensor. + num_repeats: int, number of repeated convolutional blocks. + growth_rate: int, number of filters added by each convolutional block. + name: string, block label. + """ + if name is None: + name = f"dense_block_{keras.backend.get_uid('dense_block')}" + + for i in range(num_repeats): + x = apply_conv_block(x, growth_rate, name=f"{name}_block_{i}") + return x + + +def apply_transition_block(x, compression_ratio, name=None): + """A transition block. + + Args: + x: input tensor. + compression_ratio: float, compression rate at transition layers. + name: string, block label. + """ + if name is None: + name = f"transition_block_{keras.backend.get_uid('transition_block')}" + + x = keras.layers.BatchNormalization( + axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_bn" + )(x) + x = keras.layers.Activation("relu", name=f"{name}_relu")(x) + x = keras.layers.Conv2D( + int(x.shape[BN_AXIS] * compression_ratio), + 1, + use_bias=False, + name=f"{name}_conv", + )(x) + x = keras.layers.AveragePooling2D(2, strides=2, name=f"{name}_pool")(x) + return x + + +def apply_conv_block(x, growth_rate, name=None): + """A building block for a dense block. + + Args: + x: input tensor. + growth_rate: int, number of filters added by this block. + name: string, block label. 
+ """ + if name is None: + name = f"conv_block_{keras.backend.get_uid('conv_block')}" + + shortcut = x + x = keras.layers.BatchNormalization( + axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_0_bn" + )(x) + x = keras.layers.Activation("relu", name=f"{name}_0_relu")(x) + x = keras.layers.Conv2D( + 4 * growth_rate, 1, use_bias=False, name=f"{name}_1_conv" + )(x) + x = keras.layers.BatchNormalization( + axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_1_bn" + )(x) + x = keras.layers.Activation("relu", name=f"{name}_1_relu")(x) + x = keras.layers.Conv2D( + growth_rate, + 3, + padding="same", + use_bias=False, + name=f"{name}_2_conv", + )(x) + x = keras.layers.Concatenate(axis=BN_AXIS, name=f"{name}_concat")( + [shortcut, x] + ) + return x diff --git a/keras_nlp/src/models/densenet/densenet_backbone_test.py b/keras_nlp/src/models/densenet/densenet_backbone_test.py new file mode 100644 index 0000000000..f0f8dac875 --- /dev/null +++ b/keras_nlp/src/models/densenet/densenet_backbone_test.py @@ -0,0 +1,48 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import pytest + +from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone +from keras_nlp.src.tests.test_case import TestCase + + +class DenseNetBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "stackwise_num_repeats": [6, 12, 24, 16], + "include_rescaling": True, + "compression_ratio": 0.5, + "growth_rate": 32, + "input_image_shape": (224, 224, 3), + } + self.input_data = np.ones((2, 224, 224, 3), dtype="float32") + + def test_backbone_basics(self): + self.run_backbone_test( + cls=DenseNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 7, 7, 1024), + run_mixed_precision_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=DenseNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/densenet/densenet_image_classifier.py b/keras_nlp/src/models/densenet/densenet_image_classifier.py new file mode 100644 index 0000000000..0158982f79 --- /dev/null +++ b/keras_nlp/src/models/densenet/densenet_image_classifier.py @@ -0,0 +1,124 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone +from keras_nlp.src.models.image_classifier import ImageClassifier + + +@keras_nlp_export("keras_nlp.models.DensekNetImageClassifier") +class DensekNetImageClassifier(ImageClassifier): + """DensekNet image classifier task model. + Args: + backbone: A `keras_nlp.models.DenseNetBackbone` instance. + num_classes: int. The number of classes to predict. + activation: `None`, str or callable. The activation function to use on + the `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `"softmax"`. + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + where `x` is a tensor and `y` is a integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can + be used to load a pre-trained config and weights. + Examples: + Call `predict()` to run inference. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + classifier = keras_nlp.models.DensekNetImageClassifier.from_preset( + "densenet121_imagenet") + classifier.predict(images) + ``` + Call `fit()` on a single batch. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + classifier = keras_nlp.models.DensekNetImageClassifier.from_preset( + "densenet121_imagenet") + classifier.fit(x=images, y=labels, batch_size=2) + ``` + Call `fit()` with custom loss, optimizer and backbone. + ```python + classifier = keras_nlp.models.DensekNetImageClassifier.from_preset( + "densenet121_imagenet") + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + ) + classifier.backbone.trainable = False + classifier.fit(x=images, y=labels, batch_size=2) + ``` + Custom backbone. 
+ ```python + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + backbone = keras_nlp.models.DenseNetBackbone( + stackwise_num_filters=[128, 256, 512, 1024], + stackwise_depth=[3, 9, 9, 3], + include_rescaling=False, + block_type="basic_block", + input_image_shape = (224, 224, 3), + ) + classifier = keras_nlp.models.DensekNetImageClassifier( + backbone=backbone, + num_classes=4, + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + """ + + backbone_cls = DenseNetBackbone + + def __init__( + self, + backbone, + num_classes, + activation="softmax", + preprocessor=None, # adding this dummy arg for saved model test + # TODO: once preprocessor flow is figured out, this needs to be updated + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_dense(x) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "activation": self.activation, + } + ) + return config diff --git a/keras_nlp/src/models/densenet/densenet_image_classifier_test.py b/keras_nlp/src/models/densenet/densenet_image_classifier_test.py new file mode 100644 index 0000000000..2946f0d05b --- /dev/null +++ b/keras_nlp/src/models/densenet/densenet_image_classifier_test.py @@ -0,0 +1,63 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone +from keras_nlp.src.models.densenet.densenet_image_classifier import ( + DensekNetImageClassifier, +) +from keras_nlp.src.tests.test_case import TestCase + + +class DensekNetImageClassifierTest(TestCase): + def setUp(self): + # Setup model. + self.images = np.ones((2, 16, 16, 3), dtype="float32") + self.labels = [0, 3] + self.backbone = DenseNetBackbone( + stackwise_num_repeats=[6, 12, 24, 16], + include_rescaling=True, + compression_ratio=0.5, + growth_rate=32, + input_image_shape=(224, 224, 3), + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "activation": "softmax", + } + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=DensekNetImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=DensekNetImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) From 92c1b16d8d2a3121161618dc0ed83ece938b2ea9 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Thu, 15 Aug 2024 17:09:50 +0000 Subject: [PATCH 2/6] fix testcase --- keras_nlp/src/models/densenet/densenet_image_classifier_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/keras_nlp/src/models/densenet/densenet_image_classifier_test.py b/keras_nlp/src/models/densenet/densenet_image_classifier_test.py index 2946f0d05b..a8ade13cda 100644 --- a/keras_nlp/src/models/densenet/densenet_image_classifier_test.py +++ b/keras_nlp/src/models/densenet/densenet_image_classifier_test.py @@ -31,7 +31,7 @@ def setUp(self): include_rescaling=True, compression_ratio=0.5, growth_rate=32, - input_image_shape=(224, 224, 3), + input_image_shape=(16, 16, 3), ) self.init_kwargs = { "backbone": self.backbone, From e38a0941783db7714d0c1e8560f0ba799cbe04f1 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Thu, 15 Aug 2024 21:39:18 +0000 Subject: [PATCH 3/6] address comments --- .../src/models/densenet/densenet_backbone.py | 10 +++++++--- .../densenet/densenet_image_classifier.py | 20 ++++++++++++------- .../densenet_image_classifier_test.py | 12 +++++------ 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/keras_nlp/src/models/densenet/densenet_backbone.py b/keras_nlp/src/models/densenet/densenet_backbone.py index bd2357d1f0..d11e8ffef5 100644 --- a/keras_nlp/src/models/densenet/densenet_backbone.py +++ b/keras_nlp/src/models/densenet/densenet_backbone.py @@ -23,13 +23,18 @@ @keras_nlp_export("keras_nlp.models.DenseNetBackbone") class DenseNetBackbone(Backbone): """Instantiates the DenseNet architecture. + + This class implements a DenseNet backbone as described in + [Densely Connected Convolutional Networks (CVPR 2017)]( + https://arxiv.org/abs/1608.06993 + ). Args: stackwise_num_repeats: list of ints, number of repeated convolutional blocks per dense block. include_rescaling: bool, whether to rescale the inputs. If set to `True`, inputs will be passed through a `Rescaling(1/255.0)` - layer. + layer. Defaults to `True`. input_image_shape: optional shape tuple, defaults to (224, 224, 3). compression_ratio: float, compression rate at transition layers, defaults to 0.5. 
@@ -55,9 +60,8 @@ class DenseNetBackbone(Backbone): def __init__( self, - *, stackwise_num_repeats, - include_rescaling, + include_rescaling=True, input_image_shape=(224, 224, 3), compression_ratio=0.5, growth_rate=32, diff --git a/keras_nlp/src/models/densenet/densenet_image_classifier.py b/keras_nlp/src/models/densenet/densenet_image_classifier.py index 0158982f79..1bae74ea98 100644 --- a/keras_nlp/src/models/densenet/densenet_image_classifier.py +++ b/keras_nlp/src/models/densenet/densenet_image_classifier.py @@ -18,9 +18,10 @@ from keras_nlp.src.models.image_classifier import ImageClassifier -@keras_nlp_export("keras_nlp.models.DensekNetImageClassifier") -class DensekNetImageClassifier(ImageClassifier): - """DensekNet image classifier task model. +@keras_nlp_export("keras_nlp.models.DenseNetImageClassifier") +class DenseNetImageClassifier(ImageClassifier): + """DenseNet image classifier task model. + Args: backbone: A `keras_nlp.models.DenseNetBackbone` instance. num_classes: int. The number of classes to predict. @@ -31,27 +32,31 @@ class DensekNetImageClassifier(ImageClassifier): where `x` is a tensor and `y` is a integer from `[0, num_classes)`. All `ImageClassifier` tasks include a `from_preset()` constructor which can be used to load a pre-trained config and weights. + Examples: + Call `predict()` to run inference. ```python # Load preset and train images = np.ones((2, 224, 224, 3), dtype="float32") - classifier = keras_nlp.models.DensekNetImageClassifier.from_preset( + classifier = keras_nlp.models.DenseNetImageClassifier.from_preset( "densenet121_imagenet") classifier.predict(images) ``` + Call `fit()` on a single batch. 
```python # Load preset and train images = np.ones((2, 224, 224, 3), dtype="float32") labels = [0, 3] - classifier = keras_nlp.models.DensekNetImageClassifier.from_preset( + classifier = keras_nlp.models.DenseNetImageClassifier.from_preset( "densenet121_imagenet") classifier.fit(x=images, y=labels, batch_size=2) ``` + Call `fit()` with custom loss, optimizer and backbone. ```python - classifier = keras_nlp.models.DensekNetImageClassifier.from_preset( + classifier = keras_nlp.models.DenseNetImageClassifier.from_preset( "densenet121_imagenet") classifier.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), @@ -60,6 +65,7 @@ class DensekNetImageClassifier(ImageClassifier): classifier.backbone.trainable = False classifier.fit(x=images, y=labels, batch_size=2) ``` + Custom backbone. ```python images = np.ones((2, 224, 224, 3), dtype="float32") @@ -71,7 +77,7 @@ class DensekNetImageClassifier(ImageClassifier): block_type="basic_block", input_image_shape = (224, 224, 3), ) - classifier = keras_nlp.models.DensekNetImageClassifier( + classifier = keras_nlp.models.DenseNetImageClassifier( backbone=backbone, num_classes=4, ) diff --git a/keras_nlp/src/models/densenet/densenet_image_classifier_test.py b/keras_nlp/src/models/densenet/densenet_image_classifier_test.py index a8ade13cda..60d77d489c 100644 --- a/keras_nlp/src/models/densenet/densenet_image_classifier_test.py +++ b/keras_nlp/src/models/densenet/densenet_image_classifier_test.py @@ -16,22 +16,22 @@ from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone from keras_nlp.src.models.densenet.densenet_image_classifier import ( - DensekNetImageClassifier, + DenseNetImageClassifier, ) from keras_nlp.src.tests.test_case import TestCase -class DensekNetImageClassifierTest(TestCase): +class DenseNetImageClassifierTest(TestCase): def setUp(self): # Setup model. 
- self.images = np.ones((2, 16, 16, 3), dtype="float32") + self.images = np.ones((2, 224, 224, 3), dtype="float32") self.labels = [0, 3] self.backbone = DenseNetBackbone( stackwise_num_repeats=[6, 12, 24, 16], include_rescaling=True, compression_ratio=0.5, growth_rate=32, - input_image_shape=(16, 16, 3), + input_image_shape=(224, 224, 3), ) self.init_kwargs = { "backbone": self.backbone, @@ -48,7 +48,7 @@ def test_classifier_basics(self): reason="TODO: enable after preprocessor flow is figured out" ) self.run_task_test( - cls=DensekNetImageClassifier, + cls=DenseNetImageClassifier, init_kwargs=self.init_kwargs, train_data=self.train_data, expected_output_shape=(2, 2), @@ -57,7 +57,7 @@ def test_classifier_basics(self): @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( - cls=DensekNetImageClassifier, + cls=DenseNetImageClassifier, init_kwargs=self.init_kwargs, input_data=self.images, ) From 4ffe49b37fbf33d9109715376b20285d0852fdd8 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Thu, 15 Aug 2024 21:44:03 +0000 Subject: [PATCH 4/6] nit --- keras_nlp/api/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 0f76a39577..1a0dc63681 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -70,7 +70,7 @@ ) from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone from keras_nlp.src.models.densenet.densenet_image_classifier import ( - DensekNetImageClassifier, + DenseNetImageClassifier, ) from keras_nlp.src.models.distil_bert.distil_bert_backbone import ( DistilBertBackbone, From 19344ab4c0aab54d17920b525ee62e6744307721 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Thu, 15 Aug 2024 21:51:50 +0000 Subject: [PATCH 5/6] fix lint errors --- keras_nlp/src/models/densenet/densenet_backbone.py | 6 +++--- .../src/models/densenet/densenet_image_classifier.py | 12 ++++++------ 2 files changed, 9 
insertions(+), 9 deletions(-) diff --git a/keras_nlp/src/models/densenet/densenet_backbone.py b/keras_nlp/src/models/densenet/densenet_backbone.py index d11e8ffef5..8456fbcee6 100644 --- a/keras_nlp/src/models/densenet/densenet_backbone.py +++ b/keras_nlp/src/models/densenet/densenet_backbone.py @@ -23,10 +23,10 @@ @keras_nlp_export("keras_nlp.models.DenseNetBackbone") class DenseNetBackbone(Backbone): """Instantiates the DenseNet architecture. - - This class implements a DenseNet backbone as described in + + This class implements a DenseNet backbone as described in [Densely Connected Convolutional Networks (CVPR 2017)]( - https://arxiv.org/abs/1608.06993 + https://arxiv.org/abs/1608.06993 ). Args: diff --git a/keras_nlp/src/models/densenet/densenet_image_classifier.py b/keras_nlp/src/models/densenet/densenet_image_classifier.py index 1bae74ea98..dbc71bef25 100644 --- a/keras_nlp/src/models/densenet/densenet_image_classifier.py +++ b/keras_nlp/src/models/densenet/densenet_image_classifier.py @@ -21,7 +21,7 @@ @keras_nlp_export("keras_nlp.models.DenseNetImageClassifier") class DenseNetImageClassifier(ImageClassifier): """DenseNet image classifier task model. - + Args: backbone: A `keras_nlp.models.DenseNetBackbone` instance. num_classes: int. The number of classes to predict. @@ -32,9 +32,9 @@ class DenseNetImageClassifier(ImageClassifier): where `x` is a tensor and `y` is a integer from `[0, num_classes)`. All `ImageClassifier` tasks include a `from_preset()` constructor which can be used to load a pre-trained config and weights. - + Examples: - + Call `predict()` to run inference. ```python # Load preset and train @@ -43,7 +43,7 @@ class DenseNetImageClassifier(ImageClassifier): "densenet121_imagenet") classifier.predict(images) ``` - + Call `fit()` on a single batch. 
```python # Load preset and train @@ -53,7 +53,7 @@ class DenseNetImageClassifier(ImageClassifier): "densenet121_imagenet") classifier.fit(x=images, y=labels, batch_size=2) ``` - + Call `fit()` with custom loss, optimizer and backbone. ```python classifier = keras_nlp.models.DenseNetImageClassifier.from_preset( @@ -65,7 +65,7 @@ class DenseNetImageClassifier(ImageClassifier): classifier.backbone.trainable = False classifier.fit(x=images, y=labels, batch_size=2) ``` - + Custom backbone. ```python images = np.ones((2, 224, 224, 3), dtype="float32") From b7247cb5a931cdc71147fb8d82931362900d4b37 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Thu, 15 Aug 2024 22:16:04 +0000 Subject: [PATCH 6/6] move description --- .../src/models/densenet/densenet_image_classifier.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/keras_nlp/src/models/densenet/densenet_image_classifier.py b/keras_nlp/src/models/densenet/densenet_image_classifier.py index dbc71bef25..395e8f754d 100644 --- a/keras_nlp/src/models/densenet/densenet_image_classifier.py +++ b/keras_nlp/src/models/densenet/densenet_image_classifier.py @@ -22,16 +22,17 @@ class DenseNetImageClassifier(ImageClassifier): """DenseNet image classifier task model. + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + where `x` is a tensor and `y` is an integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can + be used to load a pre-trained config and weights. + Args: backbone: A `keras_nlp.models.DenseNetBackbone` instance. num_classes: int. The number of classes to predict. activation: `None`, str or callable. The activation function to use on the `Dense` layer. Set `activation=None` to return the output logits. Defaults to `"softmax"`. - To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` - where `x` is a tensor and `y` is a integer from `[0, num_classes)`. 
- All `ImageClassifier` tasks include a `from_preset()` constructor which can - be used to load a pre-trained config and weights. Examples: