Add Video Swin model to KerasHub #1981

Open · wants to merge 20 commits into `master`
12 changes: 12 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -367,3 +367,15 @@
)
from keras_hub.src.models.xlnet.xlnet_backbone import XLNetBackbone
from keras_hub.src.tokenizers.tokenizer import Tokenizer
from keras_hub.src.models.video_swin.video_swin_aliases import (
VideoSwinBBackbone,
)
from keras_hub.src.models.video_swin.video_swin_aliases import (
VideoSwinSBackbone,
)
from keras_hub.src.models.video_swin.video_swin_aliases import (
VideoSwinTBackbone,
)
from keras_hub.src.models.video_swin.video_swin_backbone import (
VideoSwinBackbone,
)
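With these exports in place, the new backbones should become reachable from the public `keras_hub.models` namespace. A minimal smoke test, assuming the `VideoSwinTBackbone` alias (defined in `video_swin_aliases.py`, not shown in this diff) constructs the default tiny configuration:

```python
import keras
import keras_hub

# Randomly initialized tiny variant via the public API (an assumption:
# the alias uses the default (32, 224, 224, 3) input shape).
backbone = keras_hub.models.VideoSwinTBackbone()

# One dummy clip: (batch, depth, height, width, channels).
videos = keras.ops.ones((1, 32, 224, 224, 3))
features = backbone(videos)  # Expected shape: (1, 16, 7, 7, 768).
```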
4 changes: 4 additions & 0 deletions keras_hub/src/models/video_swin/__init__.py
@@ -0,0 +1,4 @@
from keras_hub.src.models.video_swin.video_swin_backbone import VideoSwinBackbone
from keras_hub.src.models.video_swin.video_swin_backbone_presets import backbone_presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, VideoSwinBackbone)
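`register_presets` is what attaches the preset table to the backbone class, so the standard `from_preset` constructor can resolve these names. A hedged sketch of the intended usage (weight files are not part of this diff, so `load_weights=False` keeps it runnable):

```python
import keras_hub

# Build the architecture described by a registered preset, without weights.
backbone = keras_hub.models.VideoSwinBackbone.from_preset(
    "videoswin_tiny",
    load_weights=False,
)
```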
237 changes: 237 additions & 0 deletions keras_hub/src/models/video_swin/video_swin_backbone.py
@@ -0,0 +1,237 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from functools import partial

import keras
import numpy as np
from keras import layers

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models import utils
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.video_swin.video_swin_backbone_presets import ( # noqa: E501
backbone_presets,
)
from keras_hub.src.models.video_swin.video_swin_backbone_presets import ( # noqa: E501
backbone_presets_with_weights,
)
from keras_hub.src.models.video_swin.video_swin_layers import (
VideoSwinBasicLayer,
)
from keras_hub.src.models.video_swin.video_swin_layers import (
VideoSwinPatchingAndEmbedding,
)
from keras_hub.src.models.video_swin.video_swin_layers import (
VideoSwinPatchMerging,
)
from keras_hub.src.utils.python_utils import classproperty


@keras_hub_export("keras_hub.models.VideoSwinBackbone")
class VideoSwinBackbone(Backbone):
    """A Video Swin Transformer backbone model.

    References:
- [Video Swin Transformer](https://arxiv.org/abs/2106.13230)
- [Official Code](https://github.com/SwinTransformer/Video-Swin-Transformer)

    Args:
        input_shape: The shape of the input video in
            `(depth, height, width, channel)` format.
            Defaults to `(32, 224, 224, 3)`.
        input_tensor: Optional tensor to use as the model input instead of
            creating a new input node from `input_shape`. Defaults to `None`.
        embed_dim: Number of linear projection output channels.
            Defaults to `96`.
        patch_size: The patch size for the depth, height, and width
            dimensions, respectively. Defaults to `[2, 4, 4]`.
        window_size: The window size for the depth, height, and width
            dimensions, respectively. Defaults to `[8, 7, 7]`.
        mlp_ratio: Ratio of the MLP hidden dimension to the embedding
            dimension. Defaults to `4.0`.
        patch_norm: If `True`, add layer normalization after patch
            embedding. Defaults to `True`.
        drop_rate: Float between 0 and 1. Fraction of the input units to
            drop. Defaults to `0.0`.
        attn_drop_rate: Float between 0 and 1. Attention dropout rate.
            Defaults to `0.0`.
        drop_path_rate: Float between 0 and 1. Stochastic depth rate.
            Defaults to `0.2`.
        depths: Depth (number of blocks) of each Swin Transformer stage.
            Defaults to `[2, 2, 6, 2]`.
        num_heads: Number of attention heads in each stage.
            Defaults to `[3, 6, 12, 24]`.
        qkv_bias: If `True`, add a learnable bias to the query, key, and
            value projections. Defaults to `True`.
        qk_scale: If set, overrides the default query-key scale of
            `head_dim ** -0.5`. Defaults to `None`.

    Example:
    ```python
    import keras
    import keras_hub

    # Build a Video Swin backbone with no top (classification) layer.
    model = keras_hub.models.VideoSwinBackbone(
        input_shape=(8, 256, 256, 3),
    )
    videos = keras.ops.ones((1, 8, 256, 256, 3))
    outputs = model.predict(videos)
    ```
""" # noqa: E501

def __init__(
self,
*,
input_shape=(32, 224, 224, 3),
input_tensor=None,
embed_dim=96,
patch_size=[2, 4, 4],
window_size=[8, 7, 7],
mlp_ratio=4.0,
patch_norm=True,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.2,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
qkv_bias=True,
qk_scale=None,
**kwargs,
):
# Parse input specification.
input_spec = utils.parse_model_inputs(
input_shape, input_tensor, name="videos"
)

# Check that the input video is well specified.
if (
input_spec.shape[-4] is None
or input_spec.shape[-3] is None
or input_spec.shape[-2] is None
):
raise ValueError(
"Depth, height and width of the video must be specified"
" in `input_shape`."
)

x = input_spec

        # if include_rescaling:
        #     # Use common rescaling strategy across keras_cv
        #     x = keras.layers.Rescaling(1.0 / 255.0)(x)

        #     # VideoSwin scales inputs based on the ImageNet mean/stddev.
        #     # Officially, Video Swin takes tensors in the [0, 255] range
        #     # and uses mean=[123.675, 116.28, 103.53] and
        #     # std=[58.395, 57.12, 57.375] for normalization.
        #     # So, if include_rescaling is set to True, then, to match the
        #     # official scores, the following normalization should be added.
        #     x = layers.Normalization(
        #         mean=[0.485, 0.456, 0.406],
        #         variance=[0.229**2, 0.224**2, 0.225**2],
        #     )(x)
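        # Sanity check (plain arithmetic, not from this diff): the official
        # [0, 255]-scale stats match the [0, 1]-scale stats above after
        # `Rescaling(1 / 255)`:
        #   123.675 / 255 = 0.485, 116.28 / 255 = 0.456, 103.53 / 255 = 0.406
        #   58.395 / 255 = 0.229, 57.12 / 255 = 0.224, 57.375 / 255 = 0.225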

norm_layer = partial(layers.LayerNormalization, epsilon=1e-05)

x = VideoSwinPatchingAndEmbedding(
patch_size=patch_size,
embed_dim=embed_dim,
norm_layer=norm_layer if patch_norm else None,
name="videoswin_patching_and_embedding",
)(x)
x = layers.Dropout(drop_rate, name="pos_drop")(x)

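        # Per-block stochastic depth rates increase linearly from 0 to
        # `drop_path_rate`; each stage below consumes its contiguous slice
        # via `dpr[sum(depths[:i]) : sum(depths[: i + 1])]`.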
dpr = np.linspace(0.0, drop_path_rate, sum(depths)).tolist()
num_layers = len(depths)
for i in range(num_layers):
layer = VideoSwinBasicLayer(
input_dim=int(embed_dim * 2**i),
depth=depths[i],
num_heads=num_heads[i],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
norm_layer=norm_layer,
downsampling_layer=(
VideoSwinPatchMerging if (i < num_layers - 1) else None
),
name=f"videoswin_basic_layer_{i + 1}",
)
x = layer(x)

x = norm_layer(axis=-1, epsilon=1e-05, name="videoswin_top_norm")(x)
super().__init__(inputs=input_spec, outputs=x, **kwargs)

# self.include_rescaling = include_rescaling
self.input_tensor = input_tensor
self.embed_dim = embed_dim
self.patch_size = patch_size
self.window_size = window_size
self.mlp_ratio = mlp_ratio
self.norm_layer = norm_layer
self.patch_norm = patch_norm
self.drop_rate = drop_rate
self.attn_drop_rate = attn_drop_rate
self.drop_path_rate = drop_path_rate
self.num_layers = len(depths)
self.num_heads = num_heads
self.qkv_bias = qkv_bias
self.qk_scale = qk_scale
self.depths = depths

def get_config(self):
config = super().get_config()
config.update(
{
"input_shape": self.input_shape[1:],
"input_tensor": self.input_tensor,
"embed_dim": self.embed_dim,
"patch_norm": self.patch_norm,
"window_size": self.window_size,
"patch_size": self.patch_size,
"mlp_ratio": self.mlp_ratio,
"drop_rate": self.drop_rate,
"drop_path_rate": self.drop_path_rate,
"attn_drop_rate": self.attn_drop_rate,
"depths": self.depths,
"num_heads": self.num_heads,
"qkv_bias": self.qkv_bias,
"qk_scale": self.qk_scale,
}
)
return config

@classproperty
def presets(cls):
"""Dictionary of preset names and configurations."""
return copy.deepcopy(backbone_presets)

@classproperty
def presets_with_weights(cls):
"""Dictionary of preset names and configurations that include
weights."""
return copy.deepcopy(backbone_presets_with_weights)

@property
def pyramid_level_inputs(self):
raise NotImplementedError(
"The `VideoSwinBackbone` model doesn't compute"
" pyramid level features."
)
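As a quick shape check on the functional graph above (a reviewer sketch, not part of the diff): the patch embedding divides depth by `patch_size[0]` and height/width by the spatial patch sizes, and each of the three `VideoSwinPatchMerging` steps halves height and width while doubling channels, so the default configuration maps `(32, 224, 224, 3)` to `(16, 7, 7, 768)`:

```python
import keras

from keras_hub.src.models.video_swin.video_swin_backbone import (
    VideoSwinBackbone,
)

backbone = VideoSwinBackbone(input_shape=(32, 224, 224, 3))
videos = keras.ops.ones((2, 32, 224, 224, 3))
features = backbone(videos)
# depth: 32 / 2; height/width: 224 / 4 / 2**3; channels: 96 * 2**3.
print(features.shape)  # Expected: (2, 16, 7, 7, 768)
```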
136 changes: 136 additions & 0 deletions keras_hub/src/models/video_swin/video_swin_backbone_presets.py
@@ -0,0 +1,136 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Video Swin model preset configurations."""

backbone_presets_no_weights = {
"videoswin_tiny": {
"metadata": {
"description": ("A tiny Video Swin backbone architecture."),
"params": 27_850_470,
"official_name": "VideoSwinT",
"path": "video_swin",
},
},
"videoswin_small": {
"metadata": {
"description": ("A small Video Swin backbone architecture."),
"params": 49_509_078,
"official_name": "VideoSwinS",
"path": "video_swin",
},
},
"videoswin_base": {
"metadata": {
"description": ("A base Video Swin backbone architecture."),
"params": 87_638_984,
"official_name": "VideoSwinB",
"path": "video_swin",
},
},
}

backbone_presets_with_weights = {
    "videoswin_tiny_kinetics400": {
        "metadata": {
            "description": (
                "A tiny Video Swin backbone architecture. "
                "It is pretrained on the ImageNet 1K dataset and then "
                "trained on the Kinetics 400 dataset."
            ),
            "params": 27_850_470,
            "official_name": "VideoSwinT",
            "path": "video_swin",
        },
    },
    "videoswin_small_kinetics400": {
        "metadata": {
            "description": (
                "A small Video Swin backbone architecture. "
                "It is pretrained on the ImageNet 1K dataset and then "
                "trained on the Kinetics 400 dataset. "
                "The published weights score 80.6% top-1 and 94.5% top-5 "
                "accuracy on the Kinetics 400 dataset."
            ),
            "params": 49_509_078,
            "official_name": "VideoSwinS",
            "path": "video_swin",
        },
    },
    "videoswin_base_kinetics400": {
        "metadata": {
            "description": (
                "A base Video Swin backbone architecture. "
                "It is pretrained on the ImageNet 1K dataset and then "
                "trained on the Kinetics 400 dataset. "
                "The published weights score 80.6% top-1 and 94.6% top-5 "
                "accuracy on the Kinetics 400 dataset."
            ),
            "params": 87_638_984,
            "official_name": "VideoSwinB",
            "path": "video_swin",
        },
    },
    "videoswin_base_kinetics400_imagenet22k": {
        "metadata": {
            "description": (
                "A base Video Swin backbone architecture. "
                "It is pretrained on the ImageNet 22K dataset and then "
                "trained on the Kinetics 400 dataset. "
                "The published weights score 82.7% top-1 and 95.5% top-5 "
                "accuracy on the Kinetics 400 dataset."
            ),
            "params": 87_638_984,
            "official_name": "VideoSwinB",
            "path": "video_swin",
        },
    },
    "videoswin_base_kinetics600_imagenet22k": {
        "metadata": {
            "description": (
                "A base Video Swin backbone architecture. "
                "It is pretrained on the ImageNet 22K dataset and then "
                "trained on the Kinetics 600 dataset. "
                "The published weights score 84.0% top-1 and 96.5% top-5 "
                "accuracy on the Kinetics 600 dataset."
            ),
            "params": 87_638_984,
            "official_name": "VideoSwinB",
            "path": "video_swin",
        },
    },
    "videoswin_base_something_something_v2": {
        "metadata": {
            "description": (
                "A base Video Swin backbone architecture. "
                "It is pretrained on the Kinetics 400 dataset and then "
                "trained on the Something Something V2 dataset. "
                "The published weights score 69.6% top-1 and 92.7% top-5 "
                "accuracy on the Something Something V2 dataset."
            ),
            "params": 87_638_984,
            "official_name": "VideoSwinB",
            "path": "video_swin",
        },
    },
}

backbone_presets = {
**backbone_presets_no_weights,
**backbone_presets_with_weights,
}
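For reference, the merged `backbone_presets` table is what the `presets` classproperty on `VideoSwinBackbone` deep-copies, so the registry can be enumerated directly; a small hedged sketch:

```python
from keras_hub.src.models.video_swin.video_swin_backbone import (
    VideoSwinBackbone,
)

# Print every registered Video Swin preset and its parameter count.
for name, preset in VideoSwinBackbone.presets.items():
    print(name, preset["metadata"]["params"])
```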