-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathnormalization.py
188 lines (179 loc) · 8.44 KB
/
normalization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from six.moves import xrange # pylint: disable=redefined-builtin
import numpy as np
from tensorflow.python.layers import base
from tensorflow.python.ops import init_ops
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import initializers
from tensorflow.python.keras.engine import InputSpec
from tensorflow.python.keras.engine import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util.tf_export import tf_export
from tensorflow.contrib.framework.python.ops import add_arg_scope
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.contrib.layers.python.layers import utils
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base
from tensorflow.python.layers import convolutional as convolutional_layers
from tensorflow.python.layers import core as core_layers
from tensorflow.python.layers import normalization as normalization_layers
from tensorflow.python.layers import pooling as pooling_layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import standard_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training import moving_averages
@add_arg_scope
def custom_norm(inputs,
                center=True,
                scale=True,
                activation_fn=None,
                reuse=None,
                variables_collections=None,
                outputs_collections=None,
                trainable=True,
                begin_norm_axis=1,
                begin_params_axis=-1,
                scope=None):
    """Adds a mean-divisive normalization layer (a Layer Normalization variant).

    The signature follows the Layer Normalization layer described in:

      "Layer Normalization"
      Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
      https://arxiv.org/abs/1607.06450

    but the arithmetic differs from that paper: the inputs are NOT centered
    and the variance is NOT used. Given a tensor `inputs` of rank `R`, the mean
    `m` is computed over axes `begin_norm_axis ... R - 1` (kept as size-1
    dimensions) and the result is

      outputs = inputs * gamma / (m + 1e-6) + beta

    where `gamma` and `beta` are omitted when `scale` / `center` are False.
    NOTE(review): because the divisor is `m + 1e-6`, a mean close to `-1e-6`
    produces very large or infinite outputs -- confirm inputs are expected to
    have a positive mean (e.g. post-ReLU activations).

    The shapes of `beta` and `gamma` are `inputs.shape[begin_params_axis:]`,
    and this part of the inputs' shape must be fully defined. They are
    broadcast against the normalized tensor.

    Args:
      inputs: A tensor having rank `R`. The mean is taken over axes
        `begin_norm_axis ... R - 1`; `beta` and `gamma` cover axes
        `begin_params_axis ... R - 1`.
      center: If True, add offset of `beta` to the scaled tensor. If False,
        `beta` is not created.
      scale: If True, multiply by `gamma`. If False, `gamma` is not created.
      activation_fn: Optional activation function applied to the result;
        `None` (the default) skips it.
      reuse: Whether or not the layer and its variables should be reused. To be
        able to reuse the layer, `scope` must be given.
      variables_collections: Optional collections for the variables.
      outputs_collections: Collections to add the outputs to.
      trainable: Passed through as the `trainable` flag of the `beta` and
        `gamma` variables.
      begin_norm_axis: The first normalization dimension: the mean is computed
        along dimensions `begin_norm_axis : rank(inputs)`. A negative value is
        interpreted relative to `rank(inputs)`.
      begin_params_axis: The first parameter (beta, gamma) dimension: the
        parameters have shape `inputs.shape[begin_params_axis:]`.
      scope: Optional scope for `variable_scope` (default name 'LayerNorm').

    Returns:
      The output tensor, after the optional activation, as returned by
      `utils.collect_named_outputs`. Its static shape is set to that of
      `inputs` before the activation is applied.

    Raises:
      ValueError: If the rank of `inputs` is not known at graph build time,
        if `begin_params_axis` or `begin_norm_axis` is not smaller than the
        rank, or if `inputs.shape[begin_params_axis:]` is not fully defined
        at graph build time.
    """
    with variable_scope.variable_scope(
            scope, 'LayerNorm', [inputs], reuse=reuse) as sc:
        inputs = ops.convert_to_tensor(inputs)
        inputs_shape = inputs.shape
        inputs_rank = inputs_shape.ndims
        if inputs_rank is None:
            raise ValueError('Inputs %s has undefined rank.' % inputs.name)
        dtype = inputs.dtype.base_dtype

        if begin_norm_axis < 0:
            begin_norm_axis = inputs_rank + begin_norm_axis
        if begin_params_axis >= inputs_rank or begin_norm_axis >= inputs_rank:
            raise ValueError('begin_params_axis (%d) and begin_norm_axis (%d) '
                             'must be < rank(inputs) (%d)' %
                             (begin_params_axis, begin_norm_axis, inputs_rank))

        params_shape = inputs_shape[begin_params_axis:]
        if not params_shape.is_fully_defined():
            raise ValueError(
                'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' %
                (inputs.name, begin_params_axis, inputs_shape))

        # Allocate parameters for the beta and gamma of the normalization.
        beta, gamma = None, None
        if center:
            beta_collections = utils.get_variable_collections(
                variables_collections, 'beta')
            beta = variables.model_variable(
                'beta',
                shape=params_shape,
                dtype=dtype,
                initializer=init_ops.zeros_initializer(),
                collections=beta_collections,
                trainable=trainable)
        if scale:
            gamma_collections = utils.get_variable_collections(
                variables_collections, 'gamma')
            gamma = variables.model_variable(
                'gamma',
                shape=params_shape,
                dtype=dtype,
                initializer=init_ops.ones_initializer(),
                collections=gamma_collections,
                trainable=trainable)

        # Only the mean of the normalized axes is used; the variance returned
        # by nn.moments is deliberately discarded.
        norm_axes = list(range(begin_norm_axis, inputs_rank))
        mean0, _ = nn.moments(inputs, norm_axes, keep_dims=True)

        # No centering is performed: the subtracted "mean" is the constant 0
        # and the multiplier is the reciprocal of the (epsilon-shifted) mean.
        mean = 0
        inv = 1. / (mean0 + 1e-6)

        with ops.name_scope(None, "batchnorm",
                            [inputs, mean, inv, gamma, beta]):
            if gamma is not None:
                inv *= gamma
            # The expression keeps the same form and operand order as before
            # so the ops generated in the graph are unchanged.
            # The result is bound to `outputs` instead of being returned here,
            # so that the shape fix-up, the activation function and the
            # output collections below are actually applied.
            outputs = inputs * math_ops.cast(inv, inputs.dtype) + math_ops.cast(
                beta - mean * inv if beta is not None else -mean * inv,
                inputs.dtype)

        outputs.set_shape(inputs_shape)
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return utils.collect_named_outputs(outputs_collections, sc.name,
                                           outputs)