Skip to content

Commit

Permalink
Snake layer and activation to learn periodic functions (tensorflow#1967)
Browse files Browse the repository at this point in the history
* Snake layer and activation to learn periodic functions.

* Removed @tf.function

* Considering all the comments.

* Black matters :)

* Moved the doc under class Snake.
  • Loading branch information
failure-to-thrive authored and ashutosh1919 committed Jul 12, 2020
1 parent 90800d6 commit 196ec53
Show file tree
Hide file tree
Showing 8 changed files with 165 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
/tensorflow_addons/activations/tests/sparsemax_test.py @andreasmadsen
/tensorflow_addons/activations/tanhshrink.py @fsx950223
/tensorflow_addons/activations/tests/tanhshrink_test.py @fsx950223
/tensorflow_addons/activations/snake.py @failure-to-thrive
/tensorflow_addons/activations/tests/snake_test.py @failure-to-thrive

/tensorflow_addons/callbacks/average_model_checkpoint.py @squadrick
/tensorflow_addons/callbacks/time_stopping.py @shun-lin
Expand Down Expand Up @@ -99,6 +101,8 @@
/tensorflow_addons/layers/tests/wrappers_test.py @seanpmorgan
/tensorflow_addons/layers/esn.py @pedrolarben
/tensorflow_addons/layers/tests/esn_test.py @pedrolarben
/tensorflow_addons/layers/snake.py @failure-to-thrive
/tensorflow_addons/layers/tests/snake_test.py @failure-to-thrive

/tensorflow_addons/losses/contrastive.py @windqaq
/tensorflow_addons/losses/tests/contrastive_test.py @windqaq
Expand Down
1 change: 1 addition & 0 deletions tensorflow_addons/activations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
from tensorflow_addons.activations.mish import mish
from tensorflow_addons.activations.softshrink import softshrink
from tensorflow_addons.activations.rrelu import rrelu
from tensorflow_addons.activations.snake import snake
from tensorflow_addons.activations.sparsemax import sparsemax
from tensorflow_addons.activations.tanhshrink import tanhshrink
37 changes: 37 additions & 0 deletions tensorflow_addons/activations/snake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import tensorflow as tf

from tensorflow_addons.utils import types


@tf.keras.utils.register_keras_serializable(package="Addons")
def snake(logits: types.TensorLike, frequency: types.Number = 1) -> tf.Tensor:
"""Snake activation to learn periodic functions.
https://arxiv.org/abs/2006.08195
Args:
logits: Input tensor.
frequency: A scalar, frequency of the periodic part.
Returns:
Tensor of the same type and shape as `logits`.
"""

logits = tf.convert_to_tensor(logits)
frequency = tf.cast(frequency, logits.dtype)

return logits + (1 - tf.cos(2 * frequency * logits)) / (2 * frequency)
1 change: 1 addition & 0 deletions tensorflow_addons/activations/tests/activations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"softshrink",
"sparsemax",
"tanhshrink",
"snake",
]


Expand Down
29 changes: 29 additions & 0 deletions tensorflow_addons/activations/tests/snake_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import pytest

import numpy as np
from tensorflow_addons.activations import snake
from tensorflow_addons.utils import test_utils


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
def test_activation(dtype):
x = dtype(np.random.rand(2, 5))
a = dtype(np.random.randn())
expected_result = x + np.power(np.sin(a * x), 2) / a
test_utils.assert_allclose_according_to_type(snake(x, a), expected_result)
1 change: 1 addition & 0 deletions tensorflow_addons/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from tensorflow_addons.layers.optical_flow import CorrelationCost
from tensorflow_addons.layers.poincare import PoincareNormalize
from tensorflow_addons.layers.polynomial import PolynomialCrossing
from tensorflow_addons.layers.snake import Snake
from tensorflow_addons.layers.sparsemax import Sparsemax
from tensorflow_addons.layers.spatial_pyramid_pooling import SpatialPyramidPooling2D
from tensorflow_addons.layers.tlu import TLU
Expand Down
53 changes: 53 additions & 0 deletions tensorflow_addons/layers/snake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements Snake layer."""

import tensorflow as tf
from typeguard import typechecked

from tensorflow_addons.activations.snake import snake

from tensorflow_addons.utils import types


@tf.keras.utils.register_keras_serializable(package="Addons")
class Snake(tf.keras.layers.Layer):
"""Snake layer to learn periodic functions with the trainable `frequency` scalar.
https://arxiv.org/abs/2006.08195
Arguments:
frequency_initializer: Initializer for the `frequency` scalar.
"""

@typechecked
def __init__(self, frequency_initializer: types.Initializer = "ones", **kwargs):
super().__init__(**kwargs)
self.frequency_initializer = tf.keras.initializers.get(frequency_initializer)
self.frequency = self.add_weight(
initializer=frequency_initializer, trainable=True
)

def call(self, inputs):
return snake(inputs, self.frequency)

def get_config(self):
config = {
"frequency_initializer": tf.keras.initializers.serialize(
self.frequency_initializer
),
}
base_config = super().get_config()
return {**base_config, **config}
39 changes: 39 additions & 0 deletions tensorflow_addons/layers/tests/snake_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Snake layer."""

import pytest

import numpy as np
import tensorflow as tf

from tensorflow_addons.layers.snake import Snake
from tensorflow_addons.activations.snake import snake

from tensorflow_addons.utils import test_utils


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
def test_layer(dtype):
x = np.random.rand(2, 5).astype(dtype)
a = np.random.randn()
val = snake(x, a)
test_utils.layer_test(
Snake,
kwargs={"frequency_initializer": tf.constant_initializer(a), "dtype": dtype},
input_data=x,
expected_output=val,
)

0 comments on commit 196ec53

Please sign in to comment.