Skip to content

Commit

Permalink
Spectral Normalization continued (#1244)
Browse files Browse the repository at this point in the history
* update code and email author

* change input_spec to include batch dim

* add pytest tmpdir signature

* add suggestions

* modify save_and_load test

* add black format

* convert LocalPath to str
  • Loading branch information
charlielito authored Aug 4, 2020
1 parent 40e406d commit 151e2f7
Show file tree
Hide file tree
Showing 4 changed files with 294 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@
/tensorflow_addons/layers/tests/polynomial_test.py @tanzhenyu
/tensorflow_addons/layers/sparsemax.py @andreasmadsen
/tensorflow_addons/layers/tests/sparsemax_test.py @andreasmadsen
/tensorflow_addons/layers/spectral_normalization.py @charlielito
/tensorflow_addons/layers/tests/spectral_normalization_test.py @charlielito
/tensorflow_addons/layers/spatial_pyramid_pooling.py @Susmit-A
/tensorflow_addons/layers/tests/spatial_pyramid_pooling_test.py @Susmit-A
/tensorflow_addons/layers/tlu.py @aakashkumarnain
Expand Down
1 change: 1 addition & 0 deletions tensorflow_addons/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from tensorflow_addons.layers.polynomial import PolynomialCrossing
from tensorflow_addons.layers.snake import Snake
from tensorflow_addons.layers.sparsemax import Sparsemax
from tensorflow_addons.layers.spectral_normalization import SpectralNormalization
from tensorflow_addons.layers.spatial_pyramid_pooling import SpatialPyramidPooling2D
from tensorflow_addons.layers.tlu import TLU
from tensorflow_addons.layers.wrappers import WeightNormalization
Expand Down
122 changes: 122 additions & 0 deletions tensorflow_addons/layers/spectral_normalization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import tensorflow as tf
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class SpectralNormalization(tf.keras.layers.Wrapper):
    """Layer wrapper that spectrally normalizes the wrapped layer's weights.

    Constraining the spectral norm of each layer bounds the Lipschitz
    constant of the network, which stabilizes the training of GANs. See
    "Spectral Normalization for Generative Adversarial Networks",
    Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida (2018):
    https://arxiv.org/abs/1802.05957

    The wrapper works for keras and tf layers that expose a `kernel` or
    `embeddings` weight.
    ```python
    net = SpectralNormalization(
        tf.keras.layers.Conv2D(2, 2, activation="relu"),
        input_shape=(32, 32, 3))(x)
    net = SpectralNormalization(
        tf.keras.layers.Conv2D(16, 5, activation="relu"))(net)
    net = SpectralNormalization(
        tf.keras.layers.Dense(120, activation="relu"))(net)
    net = SpectralNormalization(
        tf.keras.layers.Dense(n_classes))(net)
    ```
    Arguments:
      layer: a layer instance.
      power_iterations: number of power-method iterations run per training
        step to estimate the largest singular value.
    Raises:
      AssertionError: If not initialized with a `Layer` instance.
      ValueError: If initialized with negative `power_iterations`
      AttributeError: If `Layer` does not contain a `kernel` or `embeddings` of weights
    """

    @typechecked
    def __init__(self, layer: tf.keras.layers, power_iterations: int = 1, **kwargs):
        super().__init__(layer, **kwargs)
        if power_iterations <= 0:
            raise ValueError(
                "`power_iterations` should be greater than zero, got "
                "`power_iterations={}`".format(power_iterations)
            )
        # How many power-method steps to run on each training call.
        self.power_iterations = power_iterations
        # NOTE(review): set but never read in this class — presumably kept
        # for API compatibility; confirm before removing.
        self._initialized = False

    def build(self, input_shape):
        """Build the wrapped `Layer` and create the power-iteration state."""
        super().build(input_shape)
        input_shape = tf.TensorShape(input_shape)
        # Pin every dimension except the batch dimension.
        self.input_spec = tf.keras.layers.InputSpec(shape=[None] + input_shape[1:])

        # Pick the weight to normalize: `kernel` for most layers,
        # `embeddings` for embedding layers.
        for attribute in ("kernel", "embeddings"):
            if hasattr(self.layer, attribute):
                self.w = getattr(self.layer, attribute)
                break
        else:
            raise AttributeError(
                "{} object has no attribute 'kernel' nor "
                "'embeddings'".format(type(self.layer).__name__)
            )

        self.w_shape = self.w.shape.as_list()

        # Non-trainable running estimate of the leading singular vector,
        # refined by the power iteration on every training step.
        self.u = self.add_weight(
            shape=(1, self.w_shape[-1]),
            initializer=tf.initializers.TruncatedNormal(stddev=0.02),
            trainable=False,
            name="sn_u",
            dtype=self.w.dtype,
        )

    def call(self, inputs, training=None):
        """Normalize the weights when training, then run the wrapped `Layer`."""
        if training is None:
            training = tf.keras.backend.learning_phase()

        if training:
            self.normalize_weights()

        return self.layer(inputs)

    def compute_output_shape(self, input_shape):
        inner_shape = self.layer.compute_output_shape(input_shape)
        return tf.TensorShape(inner_shape.as_list())

    @tf.function
    def normalize_weights(self):
        """Generate spectral normalized weights.
        This method will update the value of self.w with the
        spectral normalized value, so that the layer is ready for call().
        """
        # Collapse the weight into a 2-D matrix; the spectral norm of this
        # matrix is what gets normalized away.
        w_matrix = tf.reshape(self.w, [-1, self.w_shape[-1]])
        u_estimate = self.u

        with tf.name_scope("spectral_normalize"):
            for _ in range(self.power_iterations):
                v_estimate = tf.math.l2_normalize(
                    tf.matmul(u_estimate, w_matrix, transpose_b=True)
                )
                u_estimate = tf.math.l2_normalize(tf.matmul(v_estimate, w_matrix))

            # sigma approximates the largest singular value of w_matrix.
            sigma = tf.matmul(
                tf.matmul(v_estimate, w_matrix), u_estimate, transpose_b=True
            )

            self.w.assign(self.w / sigma)
            self.u.assign(u_estimate)

    def get_config(self):
        config = super().get_config()
        config["power_iterations"] = self.power_iterations
        return config
169 changes: 169 additions & 0 deletions tensorflow_addons/layers/tests/spectral_normalization_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import numpy as np
import pytest
import tensorflow as tf

from tensorflow_addons.layers import spectral_normalization
from tensorflow_addons.utils import test_utils


def test_keras():
    # Run the generic Keras layer-contract checks against the wrapper.
    random_batch = np.random.random((10, 3, 4)).astype(np.float32)
    test_utils.layer_test(
        spectral_normalization.SpectralNormalization,
        kwargs={"layer": tf.keras.layers.Dense(2), "input_shape": (3, 4)},
        input_data=random_batch,
    )


def test_from_to_config():
    # A get_config/from_config round trip must preserve the hyperparameters.
    original = spectral_normalization.SpectralNormalization(tf.keras.layers.Dense(1))
    restored = spectral_normalization.SpectralNormalization.from_config(
        original.get_config()
    )
    assert original.power_iterations == restored.power_iterations


def test_save_load_model(tmpdir):
    # Round-trip a model containing the wrapper through the HDF5 format.
    sn_layer = spectral_normalization.SpectralNormalization(tf.keras.layers.Dense(1))
    inputs = tf.keras.layers.Input(shape=[1])
    model = tf.keras.models.Sequential(layers=[inputs, sn_layer])

    # A forward pass builds the model so that it can be serialized.
    model.predict(np.random.uniform(size=(2, 1)))

    model_path = str(tmpdir / "test.h5")
    model.save(model_path)
    reloaded = tf.keras.models.load_model(model_path)

    assert model.layers[0].get_config() == reloaded.layers[0].get_config()


@pytest.mark.parametrize(
    "base_layer_fn, input_shape, output_shape",
    [
        (lambda: tf.keras.layers.Dense(2), [3, 2], [3, 2]),
        (
            lambda: tf.keras.layers.Conv2D(3, (2, 2), padding="same"),
            [4, 4, 3],
            [4, 4, 3],
        ),
        (lambda: tf.keras.layers.Embedding(2, 10), [2], [2, 10]),
    ],
)
def test_model_fit(base_layer_fn, input_shape, output_shape):
    # Training a wrapped layer end-to-end must run cleanly and create the
    # power-iteration state.
    sn_layer = spectral_normalization.SpectralNormalization(base_layer_fn())
    model = tf.keras.models.Sequential(
        layers=[tf.keras.layers.Input(shape=input_shape), sn_layer]
    )
    model.add(tf.keras.layers.Activation("relu"))

    model.compile(
        optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001), loss="mse"
    )
    features = np.random.random((2, *input_shape))
    targets = np.random.random((2, *output_shape))
    model.fit(features, targets, epochs=3, batch_size=10, verbose=0)
    assert hasattr(model.layers[0], "u")


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize(
    "base_layer_fn, input_shape",
    [
        (lambda: tf.keras.layers.Dense(2), [3, 2]),
        (lambda: tf.keras.layers.Conv2D(3, (2, 2), padding="same"), [4, 4, 3]),
        (lambda: tf.keras.layers.Embedding(2, 10), [2]),
    ],
)
def test_model_build(base_layer_fn, input_shape):
    # Building the model alone must be enough to create the `u` vector.
    sn_layer = spectral_normalization.SpectralNormalization(base_layer_fn())
    model = tf.keras.models.Sequential(
        layers=[tf.keras.layers.Input(shape=input_shape), sn_layer]
    )
    model.build()
    assert hasattr(model.layers[0], "u")


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_normalization():
    """Weights are divided by their spectral norm during training only."""
    inputs = tf.keras.layers.Input(shape=[2, 2, 1])

    base_layer = tf.keras.layers.Conv2D(
        1, (2, 2), kernel_initializer=tf.constant_initializer(value=2),
    )
    sn_layer = spectral_normalization.SpectralNormalization(base_layer)
    model = tf.keras.models.Sequential(layers=[inputs, sn_layer])

    weights = np.squeeze(model.layers[0].w.numpy())
    # The spectral norm is the largest *singular* value of the weight
    # matrix, so compute the expectation with SVD. (The previous
    # eigenvalue-based check only agreed with it because this constant
    # kernel happens to be symmetric.)
    spectral_norm = np.max(np.linalg.svd(weights, compute_uv=False))
    weights_normalized = weights / spectral_norm

    for training in [False, True]:
        _ = model(
            tf.constant(np.ones((1, 2, 2, 1), dtype=np.float32)), training=training,
        )
        # Inference must leave the weights untouched; the first training
        # call rescales them in place.
        if training:
            w = weights_normalized
        else:
            w = weights
        np.testing.assert_allclose(w, np.squeeze(model.layers[0].w.numpy()))


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_apply_layer():
    # A 2x2 all-ones kernel over all-ones input sums to 4 in inference
    # mode; the kernel's spectral norm is 2, so training output is halved.
    images = tf.ones((1, 2, 2, 1))
    sn_wrapper = spectral_normalization.SpectralNormalization(
        tf.keras.layers.Conv2D(
            1, [2, 2], kernel_initializer=tf.constant_initializer(value=1)
        ),
        input_shape=(2, 2, 1),
    )

    inference_out = sn_wrapper(images, training=False)
    training_out = sn_wrapper(images, training=True)
    expected_output = np.array([[[[4.0]]]], dtype=np.float32)

    np.testing.assert_allclose(inference_out, expected_output)
    np.testing.assert_allclose(training_out, expected_output / 2)
    assert hasattr(sn_wrapper, "u")


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_no_layer():
    # Wrapping a plain tensor instead of a layer must be rejected.
    not_a_layer = tf.random.uniform((2, 4, 43))
    with pytest.raises(AssertionError):
        spectral_normalization.SpectralNormalization(not_a_layer)


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_no_kernel():
    # Layers with neither `kernel` nor `embeddings` cannot be normalized.
    wrapper = spectral_normalization.SpectralNormalization(
        tf.keras.layers.MaxPooling2D(2, 2)
    )
    with pytest.raises(AttributeError):
        wrapper.build((2, 2))

0 comments on commit 151e2f7

Please sign in to comment.