Skip to content

Add Mix transformer #1780

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions keras_nlp/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,12 @@
MistralPreprocessor,
)
from keras_nlp.src.models.mistral.mistral_tokenizer import MistralTokenizer
from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import (
MiTBackbone,
)
from keras_nlp.src.models.mix_transformer.mix_transformer_classifier import (
MiTImageClassifier,
)
from keras_nlp.src.models.opt.opt_backbone import OPTBackbone
from keras_nlp.src.models.opt.opt_causal_lm import OPTCausalLM
from keras_nlp.src.models.opt.opt_causal_lm_preprocessor import (
Expand Down
13 changes: 13 additions & 0 deletions keras_nlp/src/models/mix_transformer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
181 changes: 181 additions & 0 deletions keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras
import numpy as np
from keras import ops

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
from keras_nlp.src.models.mix_transformer.mix_transformer_layers import (
HierarchicalTransformerEncoder,
)
from keras_nlp.src.models.mix_transformer.mix_transformer_layers import (
OverlappingPatchingAndEmbedding,
)


@keras_nlp_export("keras_nlp.models.MiTBackbone")
class MiTBackbone(FeaturePyramidBackbone):
def __init__(
self,
depths,
num_layers,
blockwise_num_heads,
blockwise_sr_ratios,
end_value,
patch_sizes,
strides,
include_rescaling=True,
image_shape=(224, 224, 3),
hidden_dims=None,
**kwargs,
):
"""A Backbone implementing the MixTransformer.

This architecture to be used as a backbone for the SegFormer
architecture [SegFormer: Simple and Efficient Design for Semantic
Segmentation with Transformers](https://arxiv.org/abs/2105.15203)
[Based on the TensorFlow implementation from DeepVision](
https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/classification/mix_transformer)

Args:
depths: The number of transformer encoders to be used per layer in the
network.
num_layers: int. The number of Transformer layers.
blockwise_num_heads: list of integers, the number of heads to use
in the attention computation for each layer.
blockwise_sr_ratios: list of integers, the sequence reduction
ratio to perform for each layer on the sequence before key and
value projections. If set to > 1, a `Conv2D` layer is used to
reduce the length of the sequence.
end_value: The end value of the sequence.
include_rescaling: bool, whether to rescale the inputs. If set
to `True`, inputs will be passed through a `Rescaling(1/255.0)`
layer. Defaults to `True`.
image_shape: optional shape tuple, defaults to (224, 224, 3).
hidden_dims: the embedding dims per hierarchical layer, used as
the levels of the feature pyramid.
patch_sizes: list of integers, the patch_size to apply for each layer.
strides: list of integers, stride to apply for each layer.

Examples:

Using the class with a `backbone`:

```python
images = np.ones(shape=(1, 96, 96, 3))
labels = np.zeros(shape=(1, 96, 96, 1))
backbone = keras_nlp.models.MiTBackbone.from_preset("mit_b0_imagenet")

# Evaluate model
model(images)

# Train model
model.compile(
optimizer="adam",
loss=keras.losses.BinaryCrossentropy(from_logits=False),
metrics=["accuracy"],
)
model.fit(images, labels, epochs=3)
```
"""
dpr = [x for x in np.linspace(0.0, end_value, sum(depths))]

# === Layers ===
cur = 0
patch_embedding_layers = []
transformer_blocks = []
layer_norms = []

for i in range(num_layers):
patch_embed_layer = OverlappingPatchingAndEmbedding(
project_dim=hidden_dims[i],
patch_size=patch_sizes[i],
stride=strides[i],
name=f"patch_and_embed_{i}",
)
patch_embedding_layers.append(patch_embed_layer)

transformer_block = [
HierarchicalTransformerEncoder(
project_dim=hidden_dims[i],
num_heads=blockwise_num_heads[i],
sr_ratio=blockwise_sr_ratios[i],
drop_prob=dpr[cur + k],
name=f"hierarchical_encoder_{i}_{k}",
)
for k in range(depths[i])
]
transformer_blocks.append(transformer_block)
cur += depths[i]
layer_norms.append(keras.layers.LayerNormalization())

# === Functional Model ===
image_input = keras.layers.Input(shape=image_shape)
x = image_input

if include_rescaling:
x = keras.layers.Rescaling(scale=1 / 255)(x)

pyramid_outputs = {}
for i in range(num_layers):
# Compute new height/width after the `proj`
# call in `OverlappingPatchingAndEmbedding`
stride = strides[i]
new_height, new_width = (
int(ops.shape(x)[1] / stride),
int(ops.shape(x)[2] / stride),
)

x = patch_embedding_layers[i](x)
for blk in transformer_blocks[i]:
x = blk(x)
x = layer_norms[i](x)
x = keras.layers.Reshape(
(new_height, new_width, -1), name=f"output_level_{i}"
)(x)
pyramid_outputs[f"P{i + 1}"] = x

super().__init__(inputs=image_input, outputs=x, **kwargs)

# === Config ===
self.depths = depths
self.include_rescaling = include_rescaling
self.image_shape = image_shape
self.hidden_dims = hidden_dims
self.pyramid_outputs = pyramid_outputs
self.num_layers = num_layers
self.blockwise_num_heads = blockwise_num_heads
self.blockwise_sr_ratios = blockwise_sr_ratios
self.end_value = end_value
self.patch_sizes = patch_sizes
self.strides = strides

def get_config(self):
config = super().get_config()
config.update(
{
"depths": self.depths,
"include_rescaling": self.include_rescaling,
"hidden_dims": self.hidden_dims,
"image_shape": self.image_shape,
"num_layers": self.num_layers,
"blockwise_num_heads": self.blockwise_num_heads,
"blockwise_sr_ratios": self.blockwise_sr_ratios,
"end_value": self.end_value,
"patch_sizes": self.patch_sizes,
"strides": self.strides,
}
)
return config
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest
from keras import models

from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import (
MiTBackbone,
)
from keras_nlp.src.tests.test_case import TestCase


class MiTBackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
"depths": [2, 2],
"include_rescaling": True,
"image_shape": (16, 16, 3),
"hidden_dims": [4, 8],
"num_layers": 2,
"blockwise_num_heads": [1, 2],
"blockwise_sr_ratios": [8, 4],
"end_value": 0.1,
"patch_sizes": [7, 3],
"strides": [4, 2],
}
self.input_size = 16
self.input_data = np.ones(
(2, self.input_size, self.input_size, 3), dtype="float32"
)

def test_backbone_basics(self):
self.run_backbone_test(
cls=MiTBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 2, 2, 8),
run_quantization_check=False,
run_mixed_precision_check=False,
)

def test_pyramid_output_format(self):
init_kwargs = self.init_kwargs
backbone = MiTBackbone(**init_kwargs)
model = models.Model(backbone.inputs, backbone.pyramid_outputs)
output_data = model(self.input_data)

self.assertIsInstance(output_data, dict)
self.assertEqual(
list(output_data.keys()), list(backbone.pyramid_outputs.keys())
)
self.assertEqual(list(output_data.keys()), ["P1", "P2"])
for k, v in output_data.items():
size = self.input_size // (2 ** (int(k[1:]) + 1))
self.assertEqual(tuple(v.shape[:3]), (2, size, size))

@pytest.mark.large
def test_saved_model(self):
self.run_model_saving_test(
cls=MiTBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)
Loading
Loading