Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix GroupNormalization #611

Merged
merged 2 commits into from
Dec 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions tensorflow_addons/layers/normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,19 +161,20 @@ def _reshape_into_groups(self, inputs, input_shape, tensor_input_shape):

group_shape = [tensor_input_shape[i] for i in range(len(input_shape))]
group_shape[self.axis] = input_shape[self.axis] // self.groups
group_shape.insert(1, self.groups)
group_shape.insert(self.axis, self.groups)
group_shape = tf.stack(group_shape)
reshaped_inputs = tf.reshape(inputs, group_shape)
return reshaped_inputs, group_shape

def _apply_normalization(self, reshaped_inputs, input_shape):

group_shape = tf.keras.backend.int_shape(reshaped_inputs)
group_reduction_axes = list(range(len(group_shape)))
# Remember the ordering of the tensor is [batch, group , steps]. Jump
# the first 2 to calculate the variance and the mean
group_reduction_axes = list(range(1, len(group_shape)))
axis = -2 if self.axis == -1 else self.axis - 1
group_reduction_axes.pop(axis)

mean, variance = tf.nn.moments(
reshaped_inputs, group_reduction_axes[2:], keepdims=True)
reshaped_inputs, group_reduction_axes, keepdims=True)

gamma, beta = self._get_reshaped_weights(input_shape)
normalized_inputs = tf.nn.batch_normalization(
Expand Down Expand Up @@ -269,7 +270,7 @@ def _add_beta_weight(self, input_shape):
def _create_broadcast_shape(self, input_shape):
broadcast_shape = [1] * len(input_shape)
broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
broadcast_shape.insert(1, self.groups)
broadcast_shape.insert(self.axis, self.groups)
return broadcast_shape


Expand Down
16 changes: 8 additions & 8 deletions tensorflow_addons/layers/normalizations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def run_reshape_test(axis, group, input_shape, expected_shape):
self.assertEqual(int(group_shape[i]), expected_shape[i])

input_shape = (10, 10, 10)
expected_shape = [10, 5, 10, 2]
expected_shape = [10, 10, 5, 2]
run_reshape_test(2, 5, input_shape, expected_shape)

input_shape = (10, 10, 10)
Expand Down Expand Up @@ -108,18 +108,18 @@ def _test_specific_layer(self, inputs, axis, groups, center, scale):
np_inputs = inputs.numpy()
reshaped_dims = list(np_inputs.shape)
reshaped_dims[axis] = reshaped_dims[axis] // groups
reshaped_dims.insert(1, groups)
reshaped_dims.insert(axis, groups)
reshaped_inputs = np.reshape(np_inputs, tuple(reshaped_dims))

group_reduction_axes = list(range(1, len(reshaped_dims)))
axis = -2 if axis == -1 else axis - 1
group_reduction_axes.pop(axis)

# Calculate mean and variance
mean = np.mean(
reshaped_inputs,
axis=tuple(range(2, len(reshaped_dims))),
keepdims=True)
reshaped_inputs, axis=tuple(group_reduction_axes), keepdims=True)
variance = np.var(
reshaped_inputs,
axis=tuple(range(2, len(reshaped_dims))),
keepdims=True)
reshaped_inputs, axis=tuple(group_reduction_axes), keepdims=True)

# Get gamma and beta initalized by layer
gamma, beta = layer._get_reshaped_weights(input_shape)
Expand Down