Skip to content

Commit fc485d6

Browse files
Add Mix transformer (#1780)
* Add MixTransformer * fix testcase * test changes and comments * lint fix * update config list * modify testcase for 2 layers
1 parent fd6f977 commit fc485d6

7 files changed

+778
-0
lines changed

keras_nlp/api/models/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,12 @@
165165
MistralPreprocessor,
166166
)
167167
from keras_nlp.src.models.mistral.mistral_tokenizer import MistralTokenizer
168+
from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import (
169+
MiTBackbone,
170+
)
171+
from keras_nlp.src.models.mix_transformer.mix_transformer_classifier import (
172+
MiTImageClassifier,
173+
)
168174
from keras_nlp.src.models.opt.opt_backbone import OPTBackbone
169175
from keras_nlp.src.models.opt.opt_causal_lm import OPTCausalLM
170176
from keras_nlp.src.models.opt.opt_causal_lm_preprocessor import (
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import keras
15+
import numpy as np
16+
from keras import ops
17+
18+
from keras_nlp.src.api_export import keras_nlp_export
19+
from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
20+
from keras_nlp.src.models.mix_transformer.mix_transformer_layers import (
21+
HierarchicalTransformerEncoder,
22+
)
23+
from keras_nlp.src.models.mix_transformer.mix_transformer_layers import (
24+
OverlappingPatchingAndEmbedding,
25+
)
26+
27+
28+
@keras_nlp_export("keras_nlp.models.MiTBackbone")
29+
class MiTBackbone(FeaturePyramidBackbone):
30+
def __init__(
31+
self,
32+
depths,
33+
num_layers,
34+
blockwise_num_heads,
35+
blockwise_sr_ratios,
36+
end_value,
37+
patch_sizes,
38+
strides,
39+
include_rescaling=True,
40+
image_shape=(224, 224, 3),
41+
hidden_dims=None,
42+
**kwargs,
43+
):
44+
"""A Backbone implementing the MixTransformer.
45+
46+
This architecture to be used as a backbone for the SegFormer
47+
architecture [SegFormer: Simple and Efficient Design for Semantic
48+
Segmentation with Transformers](https://arxiv.org/abs/2105.15203)
49+
[Based on the TensorFlow implementation from DeepVision](
50+
https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/classification/mix_transformer)
51+
52+
Args:
53+
depths: The number of transformer encoders to be used per layer in the
54+
network.
55+
num_layers: int. The number of Transformer layers.
56+
blockwise_num_heads: list of integers, the number of heads to use
57+
in the attention computation for each layer.
58+
blockwise_sr_ratios: list of integers, the sequence reduction
59+
ratio to perform for each layer on the sequence before key and
60+
value projections. If set to > 1, a `Conv2D` layer is used to
61+
reduce the length of the sequence.
62+
end_value: The end value of the sequence.
63+
include_rescaling: bool, whether to rescale the inputs. If set
64+
to `True`, inputs will be passed through a `Rescaling(1/255.0)`
65+
layer. Defaults to `True`.
66+
image_shape: optional shape tuple, defaults to (224, 224, 3).
67+
hidden_dims: the embedding dims per hierarchical layer, used as
68+
the levels of the feature pyramid.
69+
patch_sizes: list of integers, the patch_size to apply for each layer.
70+
strides: list of integers, stride to apply for each layer.
71+
72+
Examples:
73+
74+
Using the class with a `backbone`:
75+
76+
```python
77+
images = np.ones(shape=(1, 96, 96, 3))
78+
labels = np.zeros(shape=(1, 96, 96, 1))
79+
backbone = keras_nlp.models.MiTBackbone.from_preset("mit_b0_imagenet")
80+
81+
# Evaluate model
82+
model(images)
83+
84+
# Train model
85+
model.compile(
86+
optimizer="adam",
87+
loss=keras.losses.BinaryCrossentropy(from_logits=False),
88+
metrics=["accuracy"],
89+
)
90+
model.fit(images, labels, epochs=3)
91+
```
92+
"""
93+
dpr = [x for x in np.linspace(0.0, end_value, sum(depths))]
94+
95+
# === Layers ===
96+
cur = 0
97+
patch_embedding_layers = []
98+
transformer_blocks = []
99+
layer_norms = []
100+
101+
for i in range(num_layers):
102+
patch_embed_layer = OverlappingPatchingAndEmbedding(
103+
project_dim=hidden_dims[i],
104+
patch_size=patch_sizes[i],
105+
stride=strides[i],
106+
name=f"patch_and_embed_{i}",
107+
)
108+
patch_embedding_layers.append(patch_embed_layer)
109+
110+
transformer_block = [
111+
HierarchicalTransformerEncoder(
112+
project_dim=hidden_dims[i],
113+
num_heads=blockwise_num_heads[i],
114+
sr_ratio=blockwise_sr_ratios[i],
115+
drop_prob=dpr[cur + k],
116+
name=f"hierarchical_encoder_{i}_{k}",
117+
)
118+
for k in range(depths[i])
119+
]
120+
transformer_blocks.append(transformer_block)
121+
cur += depths[i]
122+
layer_norms.append(keras.layers.LayerNormalization())
123+
124+
# === Functional Model ===
125+
image_input = keras.layers.Input(shape=image_shape)
126+
x = image_input
127+
128+
if include_rescaling:
129+
x = keras.layers.Rescaling(scale=1 / 255)(x)
130+
131+
pyramid_outputs = {}
132+
for i in range(num_layers):
133+
# Compute new height/width after the `proj`
134+
# call in `OverlappingPatchingAndEmbedding`
135+
stride = strides[i]
136+
new_height, new_width = (
137+
int(ops.shape(x)[1] / stride),
138+
int(ops.shape(x)[2] / stride),
139+
)
140+
141+
x = patch_embedding_layers[i](x)
142+
for blk in transformer_blocks[i]:
143+
x = blk(x)
144+
x = layer_norms[i](x)
145+
x = keras.layers.Reshape(
146+
(new_height, new_width, -1), name=f"output_level_{i}"
147+
)(x)
148+
pyramid_outputs[f"P{i + 1}"] = x
149+
150+
super().__init__(inputs=image_input, outputs=x, **kwargs)
151+
152+
# === Config ===
153+
self.depths = depths
154+
self.include_rescaling = include_rescaling
155+
self.image_shape = image_shape
156+
self.hidden_dims = hidden_dims
157+
self.pyramid_outputs = pyramid_outputs
158+
self.num_layers = num_layers
159+
self.blockwise_num_heads = blockwise_num_heads
160+
self.blockwise_sr_ratios = blockwise_sr_ratios
161+
self.end_value = end_value
162+
self.patch_sizes = patch_sizes
163+
self.strides = strides
164+
165+
def get_config(self):
166+
config = super().get_config()
167+
config.update(
168+
{
169+
"depths": self.depths,
170+
"include_rescaling": self.include_rescaling,
171+
"hidden_dims": self.hidden_dims,
172+
"image_shape": self.image_shape,
173+
"num_layers": self.num_layers,
174+
"blockwise_num_heads": self.blockwise_num_heads,
175+
"blockwise_sr_ratios": self.blockwise_sr_ratios,
176+
"end_value": self.end_value,
177+
"patch_sizes": self.patch_sizes,
178+
"strides": self.strides,
179+
}
180+
)
181+
return config
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import pytest
17+
from keras import models
18+
19+
from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import (
20+
MiTBackbone,
21+
)
22+
from keras_nlp.src.tests.test_case import TestCase
23+
24+
25+
class MiTBackboneTest(TestCase):
26+
def setUp(self):
27+
self.init_kwargs = {
28+
"depths": [2, 2],
29+
"include_rescaling": True,
30+
"image_shape": (16, 16, 3),
31+
"hidden_dims": [4, 8],
32+
"num_layers": 2,
33+
"blockwise_num_heads": [1, 2],
34+
"blockwise_sr_ratios": [8, 4],
35+
"end_value": 0.1,
36+
"patch_sizes": [7, 3],
37+
"strides": [4, 2],
38+
}
39+
self.input_size = 16
40+
self.input_data = np.ones(
41+
(2, self.input_size, self.input_size, 3), dtype="float32"
42+
)
43+
44+
def test_backbone_basics(self):
45+
self.run_backbone_test(
46+
cls=MiTBackbone,
47+
init_kwargs=self.init_kwargs,
48+
input_data=self.input_data,
49+
expected_output_shape=(2, 2, 2, 8),
50+
run_quantization_check=False,
51+
run_mixed_precision_check=False,
52+
)
53+
54+
def test_pyramid_output_format(self):
55+
init_kwargs = self.init_kwargs
56+
backbone = MiTBackbone(**init_kwargs)
57+
model = models.Model(backbone.inputs, backbone.pyramid_outputs)
58+
output_data = model(self.input_data)
59+
60+
self.assertIsInstance(output_data, dict)
61+
self.assertEqual(
62+
list(output_data.keys()), list(backbone.pyramid_outputs.keys())
63+
)
64+
self.assertEqual(list(output_data.keys()), ["P1", "P2"])
65+
for k, v in output_data.items():
66+
size = self.input_size // (2 ** (int(k[1:]) + 1))
67+
self.assertEqual(tuple(v.shape[:3]), (2, size, size))
68+
69+
@pytest.mark.large
70+
def test_saved_model(self):
71+
self.run_model_saving_test(
72+
cls=MiTBackbone,
73+
init_kwargs=self.init_kwargs,
74+
input_data=self.input_data,
75+
)

0 commit comments

Comments
 (0)