-
Notifications
You must be signed in to change notification settings - Fork 45
/
Copy pathefficientnet.py
335 lines (277 loc) · 10.2 KB
/
efficientnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
# coding=utf-8
# Copyright 2024 The Perch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""EfficientNet.
Implementation of the EfficientNet model in Flax.
"""
import dataclasses
import enum
import math
from typing import Callable, NamedTuple
from aqt.jax.v2 import aqt_conv_general
from aqt.jax.v2 import config as aqt_cfg # pylint: disable=unused-import
from chirp.models import layers
from flax import linen as nn
import flax.typing as flax_typing
import jax
from jax import numpy as jnp
class EfficientNetModel(enum.Enum):
  """Identifies an EfficientNet variant.

  Every member's value is its name in lowercase. B0 is the baseline network;
  the remaining members are scaled versions of it (see `SCALINGS`).
  """

  # Baseline.
  B0 = "b0"
  # Progressively wider/deeper variants.
  B1 = "b1"
  B2 = "b2"
  B3 = "b3"
  B4 = "b4"
  B5 = "b5"
  B6 = "b6"
  B7 = "b7"
  B8 = "b8"
  L2 = "l2"
class EfficientNetStage(NamedTuple):
  """Configuration of one stage (a run of MBConv blocks) in EfficientNet.

  Attributes:
    num_blocks: Number of MBConv blocks in the stage, before depth scaling.
    features: Output channels of each block, before width scaling.
    kernel_size: Spatial kernel size handed to every MBConv block.
    strides: Stride of the stage's first block; later blocks use stride 1.
    expand_ratio: Expansion ratio handed to every MBConv block.
  """

  num_blocks: int
  features: int
  kernel_size: tuple[int, int]
  strides: int
  expand_ratio: int
# Architecture of the baseline EfficientNet-B0. Every other variant is derived
# from these numbers via the width/depth multipliers in `SCALINGS`.
# See table 1 in the paper or
# https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_builder.py

# Output channels of the stem convolution (before width scaling).
STEM_FEATURES = 32

# The seven MBConv stages, in execution order (before width/depth scaling).
STAGES = [
    EfficientNetStage(
        num_blocks=1, features=16, kernel_size=(3, 3), strides=1, expand_ratio=1
    ),
    EfficientNetStage(
        num_blocks=2, features=24, kernel_size=(3, 3), strides=2, expand_ratio=6
    ),
    EfficientNetStage(
        num_blocks=2, features=40, kernel_size=(5, 5), strides=2, expand_ratio=6
    ),
    EfficientNetStage(
        num_blocks=3, features=80, kernel_size=(3, 3), strides=2, expand_ratio=6
    ),
    EfficientNetStage(
        num_blocks=3, features=112, kernel_size=(5, 5), strides=1, expand_ratio=6
    ),
    EfficientNetStage(
        num_blocks=4, features=192, kernel_size=(5, 5), strides=2, expand_ratio=6
    ),
    EfficientNetStage(
        num_blocks=1, features=320, kernel_size=(3, 3), strides=1, expand_ratio=6
    ),
]

# Output channels of the 1x1 head convolution (before width scaling).
HEAD_FEATURES = 1280
# Reduction ratio handed to every MBConv block's squeeze-and-excitation.
REDUCTION_RATIO = 4
class EfficientNetScaling(NamedTuple):
  """Compound-scaling coefficients of one EfficientNet variant.

  Attributes:
    width_coefficient: Multiplier applied to channel counts (see
      `round_features`).
    depth_coefficient: Multiplier applied to per-stage block counts (see
      `round_num_blocks`).
    dropout_rate: Dropout rate applied after the final pooling.
  """

  width_coefficient: float
  depth_coefficient: float
  dropout_rate: float
# Per-variant multipliers applied to the B0 baseline (`STEM_FEATURES`,
# `STAGES`, `HEAD_FEATURES`), plus the final dropout rate of each variant.
SCALINGS = {
    EfficientNetModel.B0: EfficientNetScaling(
        width_coefficient=1.0, depth_coefficient=1.0, dropout_rate=0.2
    ),
    EfficientNetModel.B1: EfficientNetScaling(
        width_coefficient=1.0, depth_coefficient=1.1, dropout_rate=0.2
    ),
    EfficientNetModel.B2: EfficientNetScaling(
        width_coefficient=1.1, depth_coefficient=1.2, dropout_rate=0.3
    ),
    EfficientNetModel.B3: EfficientNetScaling(
        width_coefficient=1.2, depth_coefficient=1.4, dropout_rate=0.3
    ),
    EfficientNetModel.B4: EfficientNetScaling(
        width_coefficient=1.4, depth_coefficient=1.8, dropout_rate=0.4
    ),
    EfficientNetModel.B5: EfficientNetScaling(
        width_coefficient=1.6, depth_coefficient=2.2, dropout_rate=0.4
    ),
    EfficientNetModel.B6: EfficientNetScaling(
        width_coefficient=1.8, depth_coefficient=2.6, dropout_rate=0.5
    ),
    EfficientNetModel.B7: EfficientNetScaling(
        width_coefficient=2.0, depth_coefficient=3.1, dropout_rate=0.5
    ),
    EfficientNetModel.B8: EfficientNetScaling(
        width_coefficient=2.2, depth_coefficient=3.6, dropout_rate=0.5
    ),
    EfficientNetModel.L2: EfficientNetScaling(
        width_coefficient=4.3, depth_coefficient=5.3, dropout_rate=0.5
    ),
}
def round_features(
    features: int, width_coefficient: float, depth_divisor: int = 8
) -> int:
  """Scales a channel count by the width multiplier and rounds it.

  The scaled count is rounded to the nearest multiple of `depth_divisor`
  (never below `depth_divisor` itself). If that rounding lost more than 10% of
  the scaled value, one extra `depth_divisor` is added back.

  Args:
    features: Unscaled number of channels.
    width_coefficient: Width multiplier of the model variant.
    depth_divisor: The result is a multiple of this value.

  Returns:
    The scaled and rounded channel count.
  """
  scaled = features * width_coefficient
  nearest_multiple = (
      int(scaled + depth_divisor / 2) // depth_divisor * depth_divisor
  )
  rounded = max(depth_divisor, nearest_multiple)
  # Guard against rounding down by more than 10%.
  lost_too_much = rounded < 0.9 * scaled
  return int(rounded + depth_divisor if lost_too_much else rounded)
def round_num_blocks(num_blocks: int, depth_coefficient: float) -> int:
  """Scales a stage's block count by the depth multiplier, rounding up.

  Args:
    num_blocks: Unscaled number of blocks in the stage.
    depth_coefficient: Depth multiplier of the model variant.

  Returns:
    The smallest integer not below `num_blocks * depth_coefficient`.
  """
  scaled_blocks = num_blocks * depth_coefficient
  return int(math.ceil(scaled_blocks))
@dataclasses.dataclass
class OpSet:
  """Bundle of ops used to assemble an EfficientNet.

  Attributes:
    activation: Activation passed to every MBConv block.
    sigmoid: Passed to every MBConv block as `sigmoid_activation`
      (presumably the squeeze-and-excitation gate; see `layers.MBConv`).
    stem_activation: Activation applied at the end of the stem.
    head_activation: Activation applied at the end of the head.
    dot_general: Optional `dot_general` op passed to the MBConv blocks.
    conv_general_dilated: Optional convolution op passed to the stem, the
      MBConv blocks and the head.
  """

  # EfficientNet uses swish in its MBConv blocks, as in the stem and head
  # below. A ReLU default made the full-precision model diverge from the
  # architecture, and made the "qat" op set's explicit `activation=nn.relu`
  # override a no-op.
  activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.swish
  sigmoid: Callable[[jnp.ndarray], jnp.ndarray] = nn.sigmoid
  stem_activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.swish
  head_activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.swish
  dot_general: flax_typing.DotGeneralT | None = None
  conv_general_dilated: flax_typing.ConvGeneralDilatedT | None = None
# Named op sets selectable through `EfficientNet.op_set`.
op_sets = dict(
    # Plain ops: every `OpSet` field keeps its default.
    default=OpSet(),
    # Presumably quantization-aware training (it plugs in an AQT convolution):
    # ReLU in the MBConv blocks, hard sigmoid/swish approximations, plain
    # `jax.lax.dot_general`, and a 2-D AQT `conv_general_dilated`.
    qat=OpSet(
        activation=nn.relu,
        sigmoid=nn.hard_sigmoid,
        stem_activation=nn.hard_swish,
        head_activation=nn.hard_swish,
        dot_general=jax.lax.dot_general,
        conv_general_dilated=aqt_conv_general.make_conv_general_dilated(
            aqt_cfg.conv_general_dilated_make(spatial_dimensions=2)
        ),
    ),
)
class Stem(nn.Module):
  """First layer, shared by every EfficientNet variant.

  A bias-free 3x3 convolution with stride 2 and "VALID" padding, followed by
  batch normalization and an activation.

  Attributes:
    features: Number of output channels of the convolution.
    conv_general_dilated: Optional convolution op passed through to `nn.Conv`.
    activation: Activation applied after batch normalization.
  """

  features: int
  conv_general_dilated: flax_typing.ConvGeneralDilatedT | None = None
  activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.swish

  @nn.compact
  def __call__(
      self, inputs: jnp.ndarray, use_running_average: bool
  ) -> jnp.ndarray:
    """Runs the stem on a batch of inputs.

    Args:
      inputs: Array of shape `(batch size, height, width, channels)`.
      use_running_average: If true, BatchNorm uses its running statistics
        (evaluation mode); otherwise it uses the current batch's statistics
        (training mode).

    Returns:
      Array of shape `(batch size, height, width, features)`, where height and
      width are those produced by the stride-2, "VALID"-padded convolution.
    """
    downsample = nn.Conv(
        features=self.features,
        kernel_size=(3, 3),
        strides=2,
        padding="VALID",
        use_bias=False,
        conv_general_dilated=self.conv_general_dilated,
    )
    normalize = nn.BatchNorm(use_running_average=use_running_average)
    return self.activation(normalize(downsample(inputs)))
class Head(nn.Module):
  """Last layer, shared by every EfficientNet variant.

  A bias-free 1x1 convolution with stride 1, followed by batch normalization
  and an activation.

  Attributes:
    features: Number of output channels of the convolution.
    activation: Activation applied after batch normalization.
    conv_general_dilated: Optional convolution op passed through to `nn.Conv`.
  """

  features: int
  activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.swish
  conv_general_dilated: flax_typing.ConvGeneralDilatedT | None = None

  @nn.compact
  def __call__(
      self, inputs: jnp.ndarray, use_running_average: bool
  ) -> jnp.ndarray:
    """Projects the final feature map to `features` channels.

    Args:
      inputs: Array of shape `(batch size, height, width, channels)`.
      use_running_average: If true, BatchNorm uses its running statistics
        (evaluation mode); otherwise it uses the current batch's statistics
        (training mode).

    Returns:
      Array of shape `(batch size, height, width, features)`.
    """
    projection = nn.Conv(
        features=self.features,
        kernel_size=(1, 1),
        strides=1,
        use_bias=False,
        conv_general_dilated=self.conv_general_dilated,
    )
    normalize = nn.BatchNorm(use_running_average=use_running_average)
    return self.activation(normalize(projection(inputs)))
class EfficientNet(nn.Module):
  """EfficientNet model.

  Attributes:
    model: The variant of EfficientNet model to use; selects the width/depth
      multipliers and the final dropout rate from `SCALINGS`.
    include_top: If true, the model averages the output over the spatial axes
      (1 and 2) and applies dropout. Note that this is different from Keras's
      `include_top` argument, which applies an additional linear
      transformation.
    survival_probability: The survival probability to use for stochastic
      depth. A falsy value (e.g. 0) disables stochastic depth.
    head: Optional Flax module to use as custom head.
    stem: Optional Flax module to use as custom stem.
    op_set: Name of the op set to use (a key of `op_sets`).
  """

  model: EfficientNetModel
  include_top: bool = True
  survival_probability: float = 0.8
  head: nn.Module | None = None
  stem: nn.Module | None = None
  op_set: str = "default"

  @nn.compact
  def __call__(
      self,
      inputs: jnp.ndarray,
      train: bool,
      use_running_average: bool | None = None,
  ) -> jnp.ndarray:
    """Applies EfficientNet to the inputs.

    The model never includes a final fully connected layer. The final spatial
    average pooling (followed by dropout) is applied only when `include_top`
    is true.

    Args:
      inputs: Inputs should be of shape `(batch size, height, width,
        channels)`.
      train: Whether this is training. This affects Dropout behavior, and also
        affects BatchNorm behavior if 'use_running_average' is set to None.
      use_running_average: Optional, used to decide whether to use running
        statistics in BatchNorm (test mode), or the current batch's statistics
        (train mode). If not specified (or specified to None), default to 'not
        train'.

    Returns:
      A JAX array of `(batch size, height, width, features)` if `include_top`
      is false. If `include_top` is true the output is `(batch_size,
      features)`.

    Raises:
      KeyError: If `op_set` is not a key of `op_sets`.
    """
    try:
      ops = op_sets[self.op_set]
    except KeyError as err:
      raise KeyError(
          f"Unknown op_set {self.op_set!r}; expected one of {sorted(op_sets)}."
      ) from err
    if use_running_average is None:
      use_running_average = not train
    scaling = SCALINGS[self.model]

    if self.stem is None:
      stem = Stem(
          round_features(STEM_FEATURES, scaling.width_coefficient),
          activation=ops.stem_activation,
          conv_general_dilated=ops.conv_general_dilated,
      )
    else:
      stem = self.stem
    x = stem(inputs, use_running_average=use_running_average)

    for stage in STAGES:
      num_blocks = round_num_blocks(stage.num_blocks, scaling.depth_coefficient)
      # Every block of a stage has the same output width, so compute it once.
      features = round_features(stage.features, scaling.width_coefficient)
      for block in range(num_blocks):
        # MBConv block with squeeze-and-excitation. Only the first block of
        # a stage uses the stage's stride.
        strides = stage.strides if block == 0 else 1
        mbconv = layers.MBConv(
            features=features,
            strides=strides,
            expand_ratio=stage.expand_ratio,
            kernel_size=stage.kernel_size,
            batch_norm=True,
            reduction_ratio=REDUCTION_RATIO,
            activation=ops.activation,
            sigmoid_activation=ops.sigmoid,
            dot_general=ops.dot_general,
            conv_general_dilated=ops.conv_general_dilated,
        )
        y = mbconv(x, train=train, use_running_average=use_running_average)
        # Stochastic depth: the dropout mask is broadcast over axes 1-3, so
        # the whole branch output is kept or dropped per example. The first
        # block of a stage has no residual connection and is never dropped.
        if block > 0 and self.survival_probability:
          y = nn.Dropout(
              1 - self.survival_probability,
              broadcast_dims=(1, 2, 3),
              deterministic=not train,
          )(y)
        # Residual connection for every block except the first of a stage.
        x = y if block == 0 else y + x

    if self.head is None:
      head = Head(
          round_features(HEAD_FEATURES, scaling.width_coefficient),
          activation=ops.head_activation,
          conv_general_dilated=ops.conv_general_dilated,
      )
    else:
      head = self.head
    x = head(x, use_running_average=use_running_average)

    if self.include_top:
      # Average over the spatial axes, leaving (batch size, features).
      x = jnp.mean(x, axis=(1, 2))
      x = nn.Dropout(rate=scaling.dropout_rate, deterministic=not train)(x)
    return x