Skip to content

Commit

Permalink
Fix GroupNormalization (#611)
Browse files Browse the repository at this point in the history
* Fix a bug, where the GroupNormalization layer was normalizing over the second axis instead of the selected axis.
* Update tests (which seem to be irrelevant anyway)
* Lint
  • Loading branch information
DawyD authored and seanpmorgan committed Dec 2, 2019
1 parent 91e9515 commit ce1e230
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 14 deletions.
13 changes: 7 additions & 6 deletions tensorflow_addons/layers/normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,19 +160,20 @@ def _reshape_into_groups(self, inputs, input_shape, tensor_input_shape):

group_shape = [tensor_input_shape[i] for i in range(len(input_shape))]
group_shape[self.axis] = input_shape[self.axis] // self.groups
group_shape.insert(1, self.groups)
group_shape.insert(self.axis, self.groups)
group_shape = tf.stack(group_shape)
reshaped_inputs = tf.reshape(inputs, group_shape)
return reshaped_inputs, group_shape

def _apply_normalization(self, reshaped_inputs, input_shape):

group_shape = tf.keras.backend.int_shape(reshaped_inputs)
group_reduction_axes = list(range(len(group_shape)))
# Remember the ordering of the tensor is [batch, group , steps]. Jump
# the first 2 to calculate the variance and the mean
group_reduction_axes = list(range(1, len(group_shape)))
axis = -2 if self.axis == -1 else self.axis - 1
group_reduction_axes.pop(axis)

mean, variance = tf.nn.moments(
reshaped_inputs, group_reduction_axes[2:], keepdims=True)
reshaped_inputs, group_reduction_axes, keepdims=True)

gamma, beta = self._get_reshaped_weights(input_shape)
normalized_inputs = tf.nn.batch_normalization(
Expand Down Expand Up @@ -268,7 +269,7 @@ def _add_beta_weight(self, input_shape):
def _create_broadcast_shape(self, input_shape):
broadcast_shape = [1] * len(input_shape)
broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
broadcast_shape.insert(1, self.groups)
broadcast_shape.insert(self.axis, self.groups)
return broadcast_shape


Expand Down
16 changes: 8 additions & 8 deletions tensorflow_addons/layers/normalizations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def run_reshape_test(axis, group, input_shape, expected_shape):
self.evaluate(group_shape[i]), expected_shape[i])

input_shape = (10, 10, 10)
expected_shape = [10, 5, 10, 2]
expected_shape = [10, 10, 5, 2]
run_reshape_test(2, 5, input_shape, expected_shape)

input_shape = (10, 10, 10)
Expand Down Expand Up @@ -110,18 +110,18 @@ def _test_specific_layer(self, inputs, axis, groups, center, scale):
np_inputs = self.evaluate(inputs)
reshaped_dims = list(np_inputs.shape)
reshaped_dims[axis] = reshaped_dims[axis] // groups
reshaped_dims.insert(1, groups)
reshaped_dims.insert(axis, groups)
reshaped_inputs = np.reshape(np_inputs, tuple(reshaped_dims))

group_reduction_axes = list(range(1, len(reshaped_dims)))
axis = -2 if axis == -1 else axis - 1
group_reduction_axes.pop(axis)

# Calculate mean and variance
mean = np.mean(
reshaped_inputs,
axis=tuple(range(2, len(reshaped_dims))),
keepdims=True)
reshaped_inputs, axis=tuple(group_reduction_axes), keepdims=True)
variance = np.var(
reshaped_inputs,
axis=tuple(range(2, len(reshaped_dims))),
keepdims=True)
reshaped_inputs, axis=tuple(group_reduction_axes), keepdims=True)

# Get gamma and beta initalized by layer
gamma, beta = layer._get_reshaped_weights(input_shape)
Expand Down

0 comments on commit ce1e230

Please sign in to comment.