Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WeightNorm support for RNNs #769

Merged
merged 1 commit into from
Dec 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions tensorflow_addons/layers/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,31 +58,34 @@ def __init__(self, layer, data_init=True, **kwargs):
super(WeightNormalization, self).__init__(layer, **kwargs)
self.data_init = data_init
self._track_trackable(layer, name='layer')
self.is_rnn = isinstance(self.layer, tf.keras.layers.RNN)

def build(self, input_shape):
"""Build `Layer`"""
input_shape = tf.TensorShape(input_shape).as_list()
input_shape = tf.TensorShape(input_shape)
self.input_spec = tf.keras.layers.InputSpec(
shape=[None] + input_shape[1:])

if not self.layer.built:
self.layer.build(input_shape)

if not hasattr(self.layer, 'kernel'):
kernel_layer = self.layer.cell if self.is_rnn else self.layer

if not hasattr(kernel_layer, 'kernel'):
raise ValueError('`WeightNormalization` must wrap a layer that'
' contains a `kernel` for weights')

# The kernel's filter or unit dimension is -1
self.layer_depth = int(self.layer.kernel.shape[-1])
self.kernel_norm_axes = list(range(self.layer.kernel.shape.rank - 1))
self.layer_depth = int(kernel_layer.kernel.shape[-1])
self.kernel_norm_axes = list(range(kernel_layer.kernel.shape.rank - 1))

self.g = self.add_weight(
name='g',
shape=(self.layer_depth,),
initializer='ones',
dtype=self.layer.kernel.dtype,
dtype=kernel_layer.kernel.dtype,
trainable=True)
self.v = self.layer.kernel
self.v = kernel_layer.kernel

self._initialized = self.add_weight(
name='initialized',
Expand All @@ -100,7 +103,10 @@ def build(self, input_shape):
layer_config)
self._naked_clone_layer.build(input_shape)
self._naked_clone_layer.set_weights(self.layer.get_weights())
self._naked_clone_layer.activation = None
if self.is_rnn:
self._naked_clone_layer.cell.activation = None
else:
self._naked_clone_layer.activation = None

self.built = True

Expand Down
7 changes: 7 additions & 0 deletions tensorflow_addons/layers/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ def test_weightnorm_with_time_dist(self):
out = tf.keras.layers.TimeDistributed(b)(inputs)
model = tf.keras.Model(inputs, out)

def test_weightnorm_with_rnn(self):
inputs = tf.keras.layers.Input(shape=(None, 3))
rnn_layer = tf.keras.layers.SimpleRNN(4)
wt_rnn = wrappers.WeightNormalization(rnn_layer)
dense = tf.keras.layers.Dense(1)
model = tf.keras.models.Sequential(layers=[inputs, wt_rnn, dense])

def test_save_file_h5(self):
self.create_tempfile('wrapper_test_model.h5')
conv = tf.keras.layers.Conv1D(1, 1)
Expand Down