
Implementation of Normalizations #14

Merged
merged 30 commits into from
Mar 9, 2019
Changes from 12 commits
30 commits
c30b75a
edited buildfile for normalizations. Implemented GroupNorm,InstanceNo…
Smokrow Jan 15, 2019
0e7674b
Resolved Comments
Smokrow Jan 17, 2019
65a5495
found bug in normalizations init
Smokrow Feb 9, 2019
892110c
minor changes
Smokrow Feb 10, 2019
57c60a7
Merge remote-tracking branch 'upstream/master' into dev/tests
Smokrow Feb 10, 2019
d5beb61
added function for easy testing.
Smokrow Feb 10, 2019
7a361ce
clean up
Smokrow Feb 10, 2019
2c174fc
found bug in BUILD File
Smokrow Feb 13, 2019
7c32461
fixed signature bug and added tests
Smokrow Feb 17, 2019
22f3073
Merge remote-tracking branch 'upstream/master' into dev/layer_norm
Smokrow Feb 17, 2019
0b04162
Update maxout.py
Smokrow Feb 17, 2019
095d91e
small change to variable name
Smokrow Feb 17, 2019
3b6d4e6
cleaned BUILD file
Smokrow Feb 24, 2019
b288cca
cleaned docstring
Smokrow Feb 24, 2019
b7e3d77
did some refactoring
Smokrow Feb 24, 2019
55cb158
refactored call function
Smokrow Feb 26, 2019
ee55d75
Merge branch 'dev/layer_norm' of https://github.com/Smokrow/addons in…
Smokrow Mar 1, 2019
540492e
fixed BUILD file
Smokrow Mar 1, 2019
f980aa5
implemented batch_normalization from tf nn
Smokrow Mar 1, 2019
576961d
added normalization and reshape test
Smokrow Mar 4, 2019
918eeb7
added axis check
Smokrow Mar 4, 2019
d2c1afd
added manual layer test
Smokrow Mar 4, 2019
4ebd907
added tests to check normalization with numpy
Smokrow Mar 5, 2019
37244c4
Included some comments from @ppwwyyxx
Smokrow Mar 5, 2019
0669466
beautified
Smokrow Mar 5, 2019
25d5569
Merge branch 'master' into dev/layer_norm
Smokrow Mar 5, 2019
b4613ae
Update normalizations.py
Smokrow Mar 7, 2019
fcd1639
Update normalizations.py
Smokrow Mar 7, 2019
e8ddabe
Merge branch 'master' into pr/Smokrow/14
seanpmorgan Mar 9, 2019
429ded2
* Standardize formatting with project
seanpmorgan Mar 9, 2019
24 changes: 23 additions & 1 deletion tensorflow_addons/layers/BUILD
@@ -10,7 +10,8 @@ py_library(
"python/maxout.py",
"python/poincare.py",
"python/wrappers.py",
],
"python/normalizations.py"
]),
srcs_version = "PY2AND3",
deps = [
"//tensorflow_addons/utils:utils_py",
@@ -43,6 +44,27 @@ py_test(
],
)

py_test(
name = "poincare_py_test",
size = "small",
srcs = [
"python/poincare_test.py",
],
main = "python/poincare_test.py",
srcs_version = "PY2AND3",
)

py_test(
name = "layers_normalizations_py_test",
srcs = [
"python/normalizations_test.py",
],
main = "python/normalizations_test.py",
deps = [
":layers_py",
],
)

py_test(
name = "poincare_py_test",
size = "small",
283 changes: 283 additions & 0 deletions tensorflow_addons/layers/python/normalizations.py
@@ -0,0 +1,283 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

# Original implementation from keras_contrib/layer/normalization

from tensorflow.keras.layers import Layer, InputSpec
from tensorflow.keras import initializers, regularizers, constraints
from tensorflow.keras import backend as K
from tensorflow.python.ops import nn

class GroupNormalization(Layer):
"""Group normalization layer.
Group Normalization divides the channels into groups and computes
within each group
the mean and variance for normalization.
Group Normalization's computation is independent
of batch sizes, and its accuracy is stable in a wide range of batch sizes.
Relation to Layer Normalization:
If the number of groups is set to 1, then this operation becomes identical to
Layer Normalization.
Relation to Instance Normalization:
If the number of groups is set to the
input dimension (number of groups is equal
to number of channels), then this operation becomes
identical to Instance Normalization.
# Arguments
groups: Integer, the number of groups for Group Normalization.
Can be in the range [1, N] where N is the input dimension.
The input dimension must be divisible by the number of groups.
axis: Integer, the axis that should be normalized
In the case of a 4D input tensor, the axes that are normalized are either (C/G, H, W) or (Batch, G), depending on your definition of "be normalized". But it is in no way "C".

The BatchNorm layer takes axis=C, therefore if you want to make an analogy here, it would be axis=[Batch, G]. This analogy is a bit ugly, so I think a better way is to still define axis to be the channel dimension, and write clearer documentation about what this layer actually does for 2D and 4D tensors, respectively.

Same comment applies to the other norms.

Contributor Author
I think it is better to remove this explanation completely since it does not really belong in the code docs. There are currently Colab notebooks planned for better documentation and explanation, and I am writing one for the layer/group/instance normalization layers. If you want, I can reference you in the PR when it is finished.
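
To make the axis semantics discussed above concrete, here is a minimal NumPy sketch of what group normalization computes for a 4D channels-last tensor with `axis=-1`; the helper name and shapes are illustrative only, not part of this PR:

```python
import numpy as np

def group_norm_reference(x, groups, eps=1e-5):
    """Reference group norm for a [N, H, W, C] tensor (channels last)."""
    n, h, w, c = x.shape
    assert c % groups == 0, "channels must be divisible by groups"
    x = x.reshape(n, h, w, groups, c // groups)
    # Statistics are computed per (sample, group): the reduction runs over
    # H, W and C/G, i.e. not over the channel axis alone.
    mean = x.mean(axis=(1, 2, 4), keepdims=True)
    var = x.var(axis=(1, 2, 4), keepdims=True)
    x = (x - mean) / np.sqrt(var + eps)
    return x.reshape(n, h, w, c)

y = group_norm_reference(np.random.randn(2, 8, 8, 16).astype("float32"), groups=4)
```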

(typically the features axis).
For instance, after a `Conv2D` layer with
`data_format="channels_first"`,
set `axis=1` in `GroupNormalization`.
epsilon: Small float added to variance to avoid dividing by zero.
center: If True, add offset of `beta` to normalized tensor.
If False, `beta` is ignored.
scale: If True, multiply by `gamma`.
If False, `gamma` is not used.
When the next layer is linear (also e.g. `nn.relu`),
this can be disabled since the scaling
will be done by the next layer.
beta_initializer: Initializer for the beta weight.
gamma_initializer: Initializer for the gamma weight.
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
beta_constraint: Optional constraint for the beta weight.
gamma_constraint: Optional constraint for the gamma weight.
# Input shape
Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
when using this layer as the first layer in a model.
# Output shape
Same shape as input.
# References
- [Group Normalization](https://arxiv.org/abs/1803.08494)
"""

def __init__(self,
groups=32,
axis=-1,
epsilon=1e-5,
center=True,
scale=True,
beta_initializer='zeros',
gamma_initializer='ones',
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
**kwargs):
super(GroupNormalization, self).__init__(**kwargs)
self.supports_masking = True
self.groups = groups
self.axis = axis
self.epsilon = epsilon
self.center = center
self.scale = scale
self.beta_initializer = initializers.get(beta_initializer)
self.gamma_initializer = initializers.get(gamma_initializer)
self.beta_regularizer = regularizers.get(beta_regularizer)
self.gamma_regularizer = regularizers.get(gamma_regularizer)
self.beta_constraint = constraints.get(beta_constraint)
self.gamma_constraint = constraints.get(gamma_constraint)

def build(self, input_shape):
dim = input_shape[self.axis]

if dim is None:
raise ValueError('Axis ' + str(self.axis) + ' of '
'input tensor should have a defined dimension '
'but the layer received an input with shape ' +
str(input_shape) + '.')
if self.groups == -1:
self.groups = dim

if dim < self.groups:
raise ValueError('Number of groups (' + str(self.groups) + ') cannot be '
'more than the number of channels (' +
str(dim) + ').')

if dim % self.groups != 0:
raise ValueError('Number of groups (' + str(self.groups) + ') must be a '
'divisor of the number of channels (' +
str(dim) + ').')

self.input_spec = InputSpec(ndim=len(input_shape),
axes={self.axis: dim})
shape = (dim,)

if self.scale:
self.gamma = self.add_weight(shape=shape,
name='gamma',
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint)
else:
self.gamma = None
if self.center:
self.beta = self.add_weight(shape=shape,
name='beta',
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint)
else:
self.beta = None
self.built = True
super(GroupNormalization, self).build(input_shape)

def call(self, inputs):
input_shape = K.int_shape(inputs)
tensor_input_shape = K.shape(inputs)

# Prepare broadcasting shape.
reduction_axes = list(range(len(input_shape)))
del reduction_axes[self.axis]
broadcast_shape = [1] * len(input_shape)
broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
broadcast_shape.insert(1, self.groups)

reshape_group_shape = K.shape(inputs)
group_axes = [reshape_group_shape[i] for i in range(len(input_shape))]
group_axes[self.axis] = input_shape[self.axis] // self.groups
group_axes.insert(1, self.groups)

# reshape inputs to new group shape
group_shape = [group_axes[0], self.groups] + group_axes[2:]
Member
Can it handle both channel-first and channel-last format?

Contributor Author
Yes. At this point the ordering would be [batch, group, channels, steps].

Contributor Author
@facaiy Since the user sets the axis to work on, I am not quite sure "channels first" vs "channels last" needs special handling. When somebody sets the axis exactly on the channel axis, I think they should be allowed to do that.
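
For reference, a minimal usage sketch of how the `axis` argument selects the channel dimension for both data formats; the import path and eager-style calls are assumptions for illustration, and the shapes are arbitrary examples:

```python
import tensorflow as tf
from tensorflow_addons.layers.python.normalizations import GroupNormalization

# channels_last: input [N, H, W, C] -> normalize groups of the last axis (default axis=-1)
x_last = tf.random.normal([2, 8, 8, 16])
y_last = GroupNormalization(groups=4, axis=-1)(x_last)

# channels_first: input [N, C, H, W] -> point axis at the channel dimension
x_first = tf.random.normal([2, 16, 8, 8])
y_first = GroupNormalization(groups=4, axis=1)(x_first)
```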

group_shape = K.stack(group_shape)
inputs = K.reshape(inputs, group_shape)

group_reduction_axes = list(range(len(group_axes)))
mean, variance = nn.moments(inputs, group_reduction_axes[2:],
keep_dims=True)
inputs = (inputs - mean) / (K.sqrt(variance + self.epsilon))

# keep the grouped shape so gamma and beta can be broadcast below
inputs = K.reshape(inputs, group_shape)

outputs = inputs

# In this case we must explicitly broadcast all parameters.
if self.scale:
broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
outputs = outputs * broadcast_gamma

if self.center:
broadcast_beta = K.reshape(self.beta, broadcast_shape)
outputs = outputs + broadcast_beta

# finally we reshape the output back to the input shape
outputs = K.reshape(outputs, tensor_input_shape)

return outputs

def get_config(self):
config = {
'groups': self.groups,
'axis': self.axis,
'epsilon': self.epsilon,
'center': self.center,
'scale': self.scale,
'beta_initializer': initializers.serialize(self.beta_initializer),
'gamma_initializer': initializers.serialize(self.gamma_initializer),
'beta_regularizer': regularizers.serialize(self.beta_regularizer),
'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
'beta_constraint': constraints.serialize(self.beta_constraint),
'gamma_constraint': constraints.serialize(self.gamma_constraint)
}
base_config = super(GroupNormalization, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

def compute_output_shape(self, input_shape):
return input_shape

class LayerNormalization(GroupNormalization):
"""Layer normalization layer.
Layer Normalization is a specific case of `GroupNormalization` since it
normalizes all features of a layer. The group size is 1.
Layer Normalization's computation is independent
of batch sizes, and its accuracy is stable in a wide range of batch sizes.
# Arguments
axis: Integer, the axis that should be normalized
(typically the features axis).
For instance, after a `Conv2D` layer with
`data_format="channels_first"`,
set `axis=1` in `LayerNormalization`.
epsilon: Small float added to variance to avoid dividing by zero.
center: If True, add offset of `beta` to normalized tensor.
If False, `beta` is ignored.
scale: If True, multiply by `gamma`.
If False, `gamma` is not used.
When the next layer is linear (also e.g. `nn.relu`),
this can be disabled since the scaling
will be done by the next layer.
beta_initializer: Initializer for the beta weight.
gamma_initializer: Initializer for the gamma weight.
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
beta_constraint: Optional constraint for the beta weight.
gamma_constraint: Optional constraint for the gamma weight.
# Input shape
Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
when using this layer as the first layer in a model.
# Output shape
Same shape as input.
# References
- [Layer Normalization](https://arxiv.org/abs/1607.06450)
"""
def __init__(self, **kwargs):
kwargs["groups"] = 1
super(LayerNormalization, self).__init__(**kwargs)

class InstanceNormalization(GroupNormalization):
"""Instance normalization layer.
Instance Normalization is a specific case of `GroupNormalization` since it
normalizes all features of one channel. The group size is equal to the channel size.
Instance Normalization's computation is independent
of batch sizes, and its accuracy is stable in a wide range of batch sizes.
# Arguments
axis: Integer, the axis that should be normalized
(typically the features axis).
For instance, after a `Conv2D` layer with
`data_format="channels_first"`,
set `axis=1` in `InstanceNormalization`.
epsilon: Small float added to variance to avoid dividing by zero.
center: If True, add offset of `beta` to normalized tensor.
If False, `beta` is ignored.
scale: If True, multiply by `gamma`.
If False, `gamma` is not used.
When the next layer is linear (also e.g. `nn.relu`),
this can be disabled since the scaling
will be done by the next layer.
beta_initializer: Initializer for the beta weight.
gamma_initializer: Initializer for the gamma weight.
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
beta_constraint: Optional constraint for the beta weight.
gamma_constraint: Optional constraint for the gamma weight.
# Input shape
Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
when using this layer as the first layer in a model.
# Output shape
Same shape as input.
# References
- [Instance Normalization](https://arxiv.org/abs/1607.08022)
"""
def __init__(self, **kwargs):
kwargs["groups"] = -1
super(InstanceNormalization, self).__init__(**kwargs)
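
To summarize how the three layers relate, a short sketch (same assumed import path and eager execution as above): `GroupNormalization` with `groups=1` behaves like `LayerNormalization`, and with `groups=-1` (one group per channel) like `InstanceNormalization`.

```python
import tensorflow as tf
from tensorflow_addons.layers.python.normalizations import (
    GroupNormalization, LayerNormalization, InstanceNormalization)

x = tf.random.normal([2, 8, 8, 16])

# LayerNormalization is GroupNormalization with a single group.
tf.debugging.assert_near(LayerNormalization()(x), GroupNormalization(groups=1)(x))

# InstanceNormalization is GroupNormalization with one group per channel.
tf.debugging.assert_near(InstanceNormalization()(x), GroupNormalization(groups=-1)(x))
```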