Add the ResNet_vd backbone #1766

Merged
5 commits merged on Aug 28, 2024

Conversation

gowthamkpr (Collaborator)

This adds the ResNet_vd backbone in preparation for the differential binarization model.

The architecture is adapted from PaddleOCR (https://github.com/PaddlePaddle/PaddleOCR/blob/main/ppocr/modeling/backbones/det_resnet_vd.py).

This is the first step of #1739:

  • Aliases have been removed from ResNet_vd
  • Model has been moved out of the subdirectory
  • Tests (from keras_cv) have been added

@divyashreepathihalli (Collaborator) left a comment:

Thank you for the PR, Gowtham! Looks good overall.
In KerasHub, when a model has multiple versions with small architectural changes, they are added as a single model and the variations are handled through the constructor args.
We have this PR for ResNet: #1765
The architectures seem similar (please let me know if this is not the case and this needs to be a separate model); can we combine ResNet_vd into that one?

@divyashreepathihalli (Collaborator) commented Aug 9, 2024:

Also, to format the code you can use shell/api_gen.sh, shell/format.sh, and shell/lint.sh at the keras-nlp root dir. You can also run the tests locally with pytest keras_nlp/src/models/resnet_vd.

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ResNet_vd backbone model.
Member:

We don't do module-level docstrings like this.


@keras_nlp_export("keras_nlp.models.ResNetVdBackbone")
class ResNetVdBackbone(Backbone):
"""Instantiates the ResNet_vd architecture.
Member:

Try to write this more in the style of our other backbones. No reference block. See BertBackbone as a good template.

stack. Use "basic_block" for ResNet18 and ResNet34.

Examples:
input_data = tf.ones(shape=(8, 224, 224, 3))
Member:

This should be inside the code block, I think.

Comment on lines 175 to 184
@classproperty
def presets(cls):
"""Dictionary of preset names and configurations."""
return copy.deepcopy(backbone_presets)

@classproperty
def presets_with_weights(cls):
"""Dictionary of preset names and configurations that include
weights."""
return copy.deepcopy(backbone_presets_with_weights)
Member:

we shouldn't need to do this anymore

"official_name": "ResNetVd",
"path": "resnet_vd",
},
# "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet18/2",
Member:

I'm confused by this. Is ResNet_vd just the same structure as ResNet?

@gowthamkpr (Collaborator Author):

I've changed the code to reuse ResNetBackbone (#1765) and to select the variant via a parameter. For this, I replaced use_pre_activation with version, which can be either "v1", "v2", or "vd". Internally, I kept use_pre_activation to keep the branching concise. This can be changed, of course; a small sketch of the branching is shown below.
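A minimal, hypothetical sketch of the branching described above (the names mirror this discussion; the arg was later renamed during review, so this is not the merged API):

# Hypothetical sketch: map the proposed `version` arg onto internal flags.
ALLOWED_VERSIONS = ("v1", "v2", "vd")

def resolve_version(version="v1"):
    if version not in ALLOWED_VERSIONS:
        raise ValueError(
            f"`version` must be one of {ALLOWED_VERSIONS}. Received: {version}"
        )
    use_pre_activation = version == "v2"  # ResNetV2-style pre-activation blocks
    use_vd_pooling = version == "vd"      # ResNet_vd-style stem and shortcuts
    return use_pre_activation, use_vd_pooling

print(resolve_version("vd"))  # (False, True)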

@divyashreepathihalli (Collaborator) left a comment:

Thanks Gowtham! I am a little confused about the added conv2d layer. A numerics verifying colab will be helpful.
Lets update the arg as well.

x = layers.BatchNormalization(
axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name="conv1_bn"
)(x)
x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x)
x = layers.Conv2D(
Collaborator:

I am a little confused about the architecture here. The VD version was supposed to differ only in its pooling layers, but here you are adding an additional Conv2D layer.

Collaborator:

A Colab verifying the numerics for this would be helpful to verify architecture correctness.

@gowthamkpr (Collaborator Author) commented Aug 14, 2024:

The modifications compared to plain ResNetV1 are:

  • Average pooling is used in the shortcut connections instead of strided Conv2D layers
  • The model's first convolutional layer is replaced with three successive ones
  • The stride configuration in the residual blocks is slightly different

This paper has some good visualizations of the ResNet_vd architecture; a minimal sketch of the stem and shortcut changes follows below.
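A minimal, illustrative sketch of the first two modifications using plain Keras layers (vd_stem and vd_shortcut are hypothetical helper names for this sketch, not the PR's code):

import keras
from keras import layers

def vd_stem(x, filters=(32, 32, 64)):
    # ResNet_vd replaces the single 7x7/stride-2 stem conv with three 3x3
    # convs; only the first one downsamples.
    for i, f in enumerate(filters):
        x = layers.Conv2D(
            f, 3, strides=2 if i == 0 else 1, padding="same",
            use_bias=False, name=f"stem_conv_{i}",
        )(x)
        x = layers.BatchNormalization(epsilon=1.001e-5, name=f"stem_bn_{i}")(x)
        x = layers.Activation("relu", name=f"stem_relu_{i}")(x)
    return x

def vd_shortcut(x, filters, stride):
    # In the vd variant, a downsampling shortcut first average-pools and then
    # applies a 1x1 conv with stride 1, instead of a strided 1x1 conv.
    if stride > 1:
        x = layers.AveragePooling2D(stride, strides=stride, padding="same")(x)
    return layers.Conv2D(filters, 1, use_bias=False)(x)

inputs = keras.Input(shape=(224, 224, 3))
x = vd_stem(inputs)
shortcut = vd_shortcut(x, filters=256, stride=2)
print(keras.Model(inputs, shortcut).output_shape)  # (None, 56, 56, 256)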

Collaborator Author:

I can look into creating a notebook to verify numerics. Since we are primarily adding this for the differential binarization model, I would check numerics against the pretrained differential binarization model from PaddleOCR, unless you think we should compare numerics for the classifier network as well.

Collaborator:

You can verify the numerics for just the ResNet_vd backbone with random weights; a sketch of such a check is shown below.
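A hedged sketch of what such a check could look like; keras.applications.ResNet50 stands in for the converted backbone here, and the reference output is faked with tiny noise, whereas in a real check it would come from the PaddleClas/PaddleOCR model run on the identical input and weights:

import numpy as np
import keras

# Stand-in for the converted backbone; weights=None gives random weights.
model = keras.applications.ResNet50(weights=None)

rng = np.random.default_rng(0)
x = rng.random((1, 224, 224, 3)).astype("float32")
keras_out = model.predict(x, verbose=0)

# Placeholder: in practice, load the reference outputs produced by the Paddle
# implementation for exactly the same input and weights.
reference_out = keras_out + rng.normal(scale=1e-7, size=keras_out.shape).astype("float32")

max_dev = np.max(np.abs(keras_out - reference_out))
print(f"max abs deviation: {max_dev:.2e}")
assert max_dev < 1e-5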

Collaborator:

Thanks, I will take a look at the paper.

@@ -86,7 +86,7 @@ class ResNetBackbone(Backbone):
stackwise_num_blocks=[2, 2, 2],
stackwise_num_strides=[1, 2, 2],
block_type="basic_block",
use_pre_activation=True,
version="v2",
pooling="avg",
Collaborator:

Let's keep the use_pre_activation arg. Instead of version, let's add an arg like use_vd_pooling or something similar.

Collaborator Author:

Sure, I can change this.
Just a note - with two booleans, we convey that you can set use_pre_activation = use_vd_pooling = True. Since both architecture variants introduce subtle modifications, it will not be entirely clear from the interface what the outcome would be. My personal preference would be to make them inherently mutually exclusive (e.g., with a guard like the sketch below).
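A hypothetical guard illustrating the concern, assuming the two-boolean interface is kept (the arg names follow the discussion above):

def validate_variant_args(use_pre_activation=False, use_vd_pooling=False):
    # The v2 (pre-activation) and vd variants each modify the architecture in
    # their own way, so enabling both at once would be ambiguous.
    if use_pre_activation and use_vd_pooling:
        raise ValueError(
            "`use_pre_activation=True` and `use_vd_pooling=True` are mutually "
            "exclusive; choose a single variant."
        )

validate_variant_args(use_vd_pooling=True)  # OK
# validate_variant_args(use_pre_activation=True, use_vd_pooling=True)  # raises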

Collaborator:

yes, we should have two different args!

@@ -209,7 +258,7 @@ def __init__(
self.stackwise_num_blocks = stackwise_num_blocks
self.stackwise_num_strides = stackwise_num_strides
self.block_type = block_type
self.use_pre_activation = use_pre_activation
self.version = version
Collaborator:

Same comment as above about renaming the version arg.

@@ -220,7 +269,7 @@ def get_config(self):
"stackwise_num_blocks": self.stackwise_num_blocks,
"stackwise_num_strides": self.stackwise_num_strides,
"block_type": self.block_type,
"use_pre_activation": self.use_pre_activation,
"version": self.version,
Collaborator:

Same comment as above about renaming the version arg.

@james77777778 (Collaborator):

Hi @divyashreepathihalli @gowthamkpr,
Can we wait for #1769?
The implementation of ResNetBackbone has been changed in that PR, and the numerical differences have been verified.

@gowthamkpr (Collaborator Author) commented Aug 19, 2024:

I've updated the PR for the newest version of the code. Numerical agreement has been verified in this notebook, and the observed numerical deviations are summarized in this document. The observed maximum deviations from PaddleClas range between 5e-7 and 3e-6.

For verification, I have converted all ResNet_vd models from PaddleClas to Keras models. Their reported ImageNet-val accuracy is shown in this figure from this link:
[figure: ImageNet-val accuracy of the converted ResNet_vd models]

The accuracy readings for the ssld models (shown in red; they use knowledge distillation) are actually really good compared to the values reported in the "ResNet strikes back" paper for timm models. This might be interesting for us.

@divyashreepathihalli (Collaborator) commented Aug 19, 2024:

Excellent!! Can we add all of these variants as presets? Hongyu added a built-in conversion from timm for most of the ResNet models - please take a look at #1769.

@mattdangerw (Member) left a comment:

Left a few comments on how we could clean up the config here.

@@ -133,7 +142,12 @@ def __init__(
'`block_type` must be either `"basic_block"` or '
f'`"bottleneck_block"`. Received block_type={block_type}.'
)
version = "v1" if not use_pre_activation else "v2"
if use_vd_pooling:
Member:

I think this if block is a bit of an anti-pattern (pre-existing, though). We don't need layer names to reflect whether something is vd, v2, or v1; we can just do that via the preset name. Maintaining something like this will be fragile.

Collaborator:

The timm conversion of ResNet might fail after renaming the layers. I can fix it later if needed.

Member:

@james77777778 Yes, please check if we need a small fix after this lands!

x = layers.BatchNormalization(
axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name="conv1_bn"
)(x)
x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x)
x = layers.Conv2D(
Member:

I think, overall, we want to avoid something like this, where a single flag is used as a proxy for a lot of magic numbers.

We want the magic numbers to live in the model's config. The 64 that was here before for the single conv layer is probably better written as stackwise_num_filters[0] (but at least we were throwing when stackwise_num_filters[0] != 64 in that case).

@@ -106,6 +114,7 @@ def __init__(
stackwise_num_strides,
block_type,
use_pre_activation=False,
use_vd_pooling=False,
@mattdangerw (Member) commented Aug 19, 2024:

I don't think we want this one flag that proxies a bunch of options. This will be hard to maintain and will lead to confusing configs.

I think we might want to add a new block type, "vd_block", and two new args here: input_conv_filters=(64,) and input_conv_kernel_sizes=(7,). Or something like that.

Then the config for ResNet_vd would look like:

{
    ...,
    "input_conv_filters": (32, 32, 64),
    "input_conv_kernel_sizes": (3, 3, 3),
    "block_type": "vd_block",
}

Then you could implement the input conv layers as a loop, and implement the new vd block by copying basic_block with some minor changes (a sketch of the loop idea is shown after this comment).

WDYT? @james77777778 any tweaks as well?

Not sure I'm getting the exact args right, but at a high level, avoid hard-coding magic numbers that define the structure of this model behind feature flags. Instead, they should live directly in the config.
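A small, hypothetical sketch of the config-driven input conv loop suggested above (arg names follow the proposal; this is not the merged implementation):

import keras
from keras import layers

def apply_input_convs(x, input_conv_filters=(64,), input_conv_kernel_sizes=(7,)):
    # Vanilla ResNet: a single 7x7 stem conv. ResNet_vd: three 3x3 convs, e.g.
    # input_conv_filters=(32, 32, 64), input_conv_kernel_sizes=(3, 3, 3).
    for i, (filters, kernel_size) in enumerate(
        zip(input_conv_filters, input_conv_kernel_sizes)
    ):
        x = layers.Conv2D(
            filters,
            kernel_size,
            strides=2 if i == 0 else 1,  # only the first conv downsamples
            padding="same",
            use_bias=False,
            name=f"conv{i}_conv",
        )(x)
        x = layers.BatchNormalization(epsilon=1.001e-5, name=f"conv{i}_bn")(x)
        x = layers.Activation("relu", name=f"conv{i}_relu")(x)
    return x

inputs = keras.Input(shape=(224, 224, 3))
x = apply_input_convs(inputs, (32, 32, 64), (3, 3, 3))
print(keras.Model(inputs, x).output_shape)  # (None, 112, 112, 64)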

Collaborator:

This sounds good to me. It's clearer than using a bool/str arg.

@@ -155,21 +169,21 @@ def __init__(
# The padding between torch and tensorflow/jax differs when `strides>1`.
# Therefore, we need to manually pad the tensor.
x = layers.ZeroPadding2D(
3,
1 if use_vd_pooling else 3,
Member:

Under the proposal above, this might be able to be written as input_conv_kernel_sizes[0] // 2 or something like that?
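A quick plain-Python check, just for illustration, that the suggested expression reproduces both existing cases (padding 3 for the vanilla 7x7 stem conv, padding 1 for the 3x3 vd stem conv):

for input_conv_kernel_sizes in ((7,), (3, 3, 3)):
    # Manual ZeroPadding2D amount before the first strided stem conv.
    print(input_conv_kernel_sizes[0] // 2)  # prints 3, then 1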

@@ -580,7 +678,12 @@ def apply_stack(
Output tensor for the stacked blocks.
"""
if name is None:
version = "v1" if not use_pre_activation else "v2"
if use_vd_pooling:
Member:

Similar to above, I think we should remove this. The preset name is how we track the difference between vd, v2, and v1. This will be hard to maintain.

@gowthamkpr (Collaborator Author):

> Excellent!! Can we add all of these variants as presets? Hongyu added a built-in conversion from timm for most of the ResNet models - please take a look at #1769.

Yes. Model files are available from the linked notebook. Both keras-nlp and PaddleClas are licensed under Apache 2.0.

@mattdangerw (Member) left a comment:

Nice, looking good! I think there might be some extraneous additions to the basic block to remove.

@@ -461,7 +527,295 @@ def apply_bottleneck_block(
use_bias=False,
dtype=dtype,
name=f"{name}_0_conv",
)(shortcut)
Member:

What's up with this 50-or-so-line addition to the basic block? Shouldn't all the changes be in the VD blocks now? Am I missing something?

Collaborator Author:

Just like the normal ResNet, the smaller ResNet_vd variants use basic blocks and the bigger ones use bottleneck blocks. Therefore, we need vd variants of both basic_block and bottleneck_block.

Member:

Oh, I think this was just the GitHub diff being weird, my bad! It looked like you were making a massive addition to the non-vd block, but I see now that's just because the code matches across both. Let's try this out!

@mattdangerw (Member) left a comment:

lgtm!

@mattdangerw mattdangerw merged commit be8888d into keras-team:keras-hub Aug 28, 2024
7 checks passed
mattdangerw pushed a commit to mattdangerw/keras-hub that referenced this pull request Sep 10, 2024
* Add ResNet_vd to ResNet backbone

* Addressed requested parameter changes

* Fixed tests and updated comments

* Added new parameters to docstring
mattdangerw pushed a commit that referenced this pull request Sep 11, 2024

mattdangerw pushed a commit that referenced this pull request Sep 13, 2024

mattdangerw pushed a commit that referenced this pull request Sep 17, 2024
divyashreepathihalli added a commit that referenced this pull request Sep 25, 2024
* Add VGG16 backbone (#1737)

* Agg Vgg16 backbone

* update names

* update tests

* update test

* add image classifier

* incorporate review comments

* Update test case

* update backbone test

* add image classifier

* classifier cleanup

* code reformat

* add vgg16 image classifier

* make vgg generic

* update doc string

* update docstring

* add classifier test

* update tests

* update docstring

* address review comments

* code reformat

* update the configs

* address review comments

* fix task saved model test

* update init

* code reformatted

* Add `ResNetBackbone` and `ResNetImageClassifier` (#1765)

* Add ResNetV1 and ResNetV2

* Address comments

* Add CSP DarkNet backbone and classifier (#1774)

* Add CSP DarkNet

* Add CSP DarkNet

* snake_case function names

* change use_depthwise to block_type

* Add `FeaturePyramidBackbone` and port weights from `timm` for `ResNetBackbone` (#1769)

* Add FeaturePyramidBackbone and update ResNetBackbone

* Simplify the implementation

* Fix CI

* Make ResNetBackbone compatible with timm and add FeaturePyramidBackbone

* Add conversion implementation

* Update docstrings

* Address comments

* Add DenseNet (#1775)

* Add DenseNet

* fix testcase

* address comments

* nit

* fix lint errors

* move description

* Add ViTDetBackbone (#1776)

* add vit det vit_det_backbone

* update docstring

* code reformat

* fix tests

* address review comments

* bump year on all files

* address review comments

* rename backbone

* fix tests

* change back to ViT

* address review comments

* update image shape

* Add Mix transformer (#1780)

* Add MixTransformer

* fix testcase

* test changes and comments

* lint fix

* update config list

* modify testcase for 2 layers

* update input_image_shape -> image_shape (#1785)

* update input_image_shape -> image_shape

* update docstring example

* code reformat

* update tests

* Create __init__.py (#1788)

add missing __init__ file to vit_det

* Hack package build script to rename to keras-hub (#1793)

This is a temporary way to test out the keras-hub branch.
- Does a global rename of all symbols during package build.
- Registers the "old" name on symbol export for saving compat.
- Adds a github action to publish every commit to keras-hub as
  a new package.
- Removes our descriptions on PyPI temporarily, until we want
  to message this more broadly.

* Add CLIP and T5XXL for StableDiffusionV3 (#1790)

* Add `CLIPTokenizer`, `T5XXLTokenizer`, `CLIPTextEncoder` and `T5XXLTextEncoder`.

* Make CLIPTextEncoder as Backbone

* Add `T5XXLPreprocessor` and remove `T5XXLTokenizer`

Add `CLIPPreprocessor`

* Use `tf = None` at the top

* Replace manual implementation of `CLIPAttention` with `MultiHeadAttention`

* Add Bounding Box Utils (#1791)

* Bounding box utils

* - Correct test cases

* - Remove hard tensorflow dtype

* - fix api gen

* - Fix import for test cases
- Use setup for converters test case

* - fix api_gen issue

* - FIx api gen

* - Fix api gen error

* - Correct test cases as per new api changes

* mobilenet_v3 added in keras-nlp (#1782)

* mobilenet_v3 added in keras-nlp

* minor bug fixed in mobilenet_v3_backbone

* formatting corrected

* refactoring backbone

* correct_pad_downsample method added

* refactoring backbone

* parameters updated

* Testcaseupdated, expected output shape corrected

* code formatted with black

* testcase updated

* refactoring and description added

* comments updated

* added mobilenet v1 and v2

* merge conflict resolved

* version arg removed, and config options added

* input_shape changed to image_shape in arg

* config updated

* input shape corrected

* comments resolved

* activation function format changed

* minor bug fixed

* minor bug fixed

* added vision_backbone_test

* channel_first bug resolved

* channel_first cases working

* comments  resolved

* formatting fixed

* refactoring

---------

Co-authored-by: ushareng <[email protected]>

* Pkgoogle/efficient net migration (#1778)

* migrating efficientnet models to keras-hub

* merging changes from other sources

* autoformatting pass

* initial consolidation of efficientnet_backbone

* most updates and removing separate implementation

* cleanup, autoformatting, keras generalization

* removed layer examples outside of effiicient net

* many, mainly documentation changes, small test fixes

* Add the ResNet_vd backbone (#1766)

* Add ResNet_vd to ResNet backbone

* Addressed requested parameter changes

* Fixed tests and updated comments

* Added new parameters to docstring

* Add `VAEImageDecoder` for StableDiffusionV3 (#1796)

* Add `VAEImageDecoder` for StableDiffusionV3

* Use `keras.Model` for `VAEImageDecoder` and follows the coding style in `VAEAttention`

* Replace `Backbone` with `keras.Model` in `CLIPTextEncoder` and `T5XXLTextEncoder` (#1802)

* Add pyramid output for densenet, cspDarknet (#1801)

* add pyramid outputs

* fix testcase

* format fix

* make common testcase for pyramid outputs

* change default shape

* simplify testcase

* test case change and add channel axis

* Add `MMDiT` for StableDiffusionV3 (#1806)

* Add `MMDiT`

* Update

* Update

* Update implementation

* Add remaining bbox utils (#1804)

* - Add formats, iou, utils for bounding box

* - Add `AnchorGenerator`, `BoxMatcher` and `NonMaxSupression` layers

* - Remove scope_name  not required.

* use default keras name scope

* - Correct format error

* - Remove layers as of now and keep them at model level till keras core supports them

* - Correct api_gen

* Fix timm conversion for rersnet (#1814)

* Add `StableDiffusion3`

* Fix `_normalize_inputs`

* Separate CLIP encoders from SD3 backbone.

* Simplify `text_to_image` function.

* Address comments

* Minor update and add docstrings.

* Fix

* Update

* Rename to diffuser and decoder

* Define functional model

* Merge from upstream/master

* Delete old SD3

* Fix copyright

* Rename to keras_hub

* Address comments

* Update

* Fix CI

* Fix bugs occurred in keras3.1

---------

Co-authored-by: Divyashree Sreepathihalli <[email protected]>
Co-authored-by: Sachin Prasad <[email protected]>
Co-authored-by: Matt Watson <[email protected]>
Co-authored-by: Siva Sravana Kumar Neeli <[email protected]>
Co-authored-by: Usha Rengaraju <[email protected]>
Co-authored-by: ushareng <[email protected]>
Co-authored-by: pkgoogle <[email protected]>
Co-authored-by: gowthamkpr <[email protected]>