-
Notifications
You must be signed in to change notification settings - Fork 613
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add lisht kernel * update README * format code * fix tolerance * reorder the computation * unify namespace * clean up testcase * format code * fix typo * fix namespace comment * remove extra the * change test size to small
- Loading branch information
Showing
10 changed files
with
433 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,25 +4,26 @@ | |
| Submodule | Maintainers | Contact Info | | ||
|:----------|:--------------------------|:-----------------------------------------| | ||
| gelu | @AakashKumarNain @WindQAQ | [email protected] [email protected] | | ||
| hardshrink| @WindQAQ | [email protected] | ||
| hardshrink| @WindQAQ | [email protected] | | ||
| lisht | @WindQAQ | [email protected] | | ||
| sparsemax | @AndreasMadsen | [email protected] | | ||
| tanhshrink | @fsx950223 | [email protected] | | ||
| tanhshrink| @fsx950223 | [email protected] | | ||
|
||
## Contents | ||
| Submodule | Activation | Reference | | ||
|:----------|:-----------|:---------------------------------| | ||
| gelu | gelu | https://arxiv.org/abs/1606.08415 | | ||
| hardshrink| hardshrink | | | ||
| sparsemax | Sparsemax | https://arxiv.org/abs/1602.02068 | | ||
| tanhshrink | Tanhshrink | | | ||
| lisht | lisht | https://arxiv.org/abs/1901.05894 | | ||
| sparsemax | sparsemax | https://arxiv.org/abs/1602.02068 | | ||
| tanhshrink| tanhshrink | | | ||
|
||
|
||
## Contribution Guidelines | ||
#### Standard API | ||
In order to conform with the current API standard, all activations | ||
must: | ||
* Be a `tf.function`. | ||
* Have the signature `fn(input, axis=-1, name=None)`. | ||
* [Register as a keras global object](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/utils/python/keras_utils.py) | ||
so it can be serialized properly. | ||
* Add the addon to the `py_library` in this sub-package's BUILD file. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import tensorflow as tf | ||
from tensorflow_addons.utils import keras_utils | ||
from tensorflow_addons.utils.resource_loader import get_path_to_datafile | ||
|
||
_activation_ops_so = tf.load_op_library( | ||
get_path_to_datafile("custom_ops/activations/_activation_ops.so")) | ||
|
||
|
||
@keras_utils.register_keras_custom_object | ||
@tf.function | ||
def lisht(x): | ||
"""LiSHT: Non-Parameteric Linearly Scaled Hyperbolic Tangent Activation Function. | ||
Computes linearly scaled hyperbolic tangent (LiSHT): `x * tanh(x)` | ||
See [LiSHT: Non-Parameteric Linearly Scaled Hyperbolic Tangent Activation Function for Neural Networks](https://arxiv.org/abs/1901.05894). | ||
Args: | ||
x: A `Tensor`. Must be one of the following types: | ||
`float16`, `float32`, `float64`. | ||
Returns: | ||
A `Tensor`. Has the same type as `x`. | ||
""" | ||
x = tf.convert_to_tensor(x) | ||
return _activation_ops_so.addons_lisht(x) | ||
|
||
|
||
@tf.RegisterGradient("Addons>Lisht") | ||
def _lisht_grad(op, grad): | ||
return _activation_ops_so.addons_lisht_grad(grad, op.inputs[0]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
from absl.testing import parameterized | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
from tensorflow_addons.activations import lisht | ||
from tensorflow_addons.utils import test_utils | ||
|
||
|
||
@test_utils.run_all_in_graph_and_eager_modes | ||
class LishtTest(tf.test.TestCase, parameterized.TestCase): | ||
@parameterized.named_parameters(("float16", np.float16), | ||
("float32", np.float32), | ||
("float64", np.float64)) | ||
def test_lisht(self, dtype): | ||
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype) | ||
expected_result = tf.constant( | ||
[1.9280552, 0.7615942, 0.0, 0.7615942, 1.9280552], dtype=dtype) | ||
self.assertAllCloseAccordingToType(lisht(x), expected_result) | ||
|
||
@parameterized.named_parameters(("float32", np.float32), | ||
("float64", np.float64)) | ||
def test_theoretical_gradients(self, dtype): | ||
# Only test theoretical gradients for float32 and float64 | ||
# because of the instability of float16 while computing jacobian | ||
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype) | ||
|
||
theoretical, numerical = tf.test.compute_gradient(lisht, [x]) | ||
self.assertAllCloseAccordingToType( | ||
theoretical, numerical, rtol=5e-4, atol=5e-4) | ||
|
||
def test_unknown_shape(self): | ||
fn = lisht.get_concrete_function( | ||
tf.TensorSpec(shape=None, dtype=tf.float32)) | ||
|
||
for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]: | ||
x = tf.ones(shape=shape, dtype=tf.float32) | ||
self.assertAllClose(fn(x), lisht(x)) | ||
|
||
def test_serialization(self): | ||
config = tf.keras.activations.serialize(lisht) | ||
fn = tf.keras.activations.deserialize(config) | ||
self.assertEqual(fn, lisht) | ||
|
||
def test_serialization_with_layers(self): | ||
layer = tf.keras.layers.Dense(3, activation=lisht) | ||
config = tf.keras.layers.serialize(layer) | ||
deserialized_layer = tf.keras.layers.deserialize(config) | ||
self.assertEqual(deserialized_layer.__class__.__name__, | ||
layer.__class__.__name__) | ||
self.assertEqual(deserialized_layer.activation.__name__, "lisht") | ||
|
||
|
||
if __name__ == "__main__": | ||
tf.test.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
79 changes: 79 additions & 0 deletions
79
tensorflow_addons/custom_ops/activations/cc/kernels/lisht_op.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#define EIGEN_USE_THREADS | ||
|
||
#include "tensorflow_addons/custom_ops/activations/cc/kernels/lisht_op.h" | ||
#include "tensorflow/core/framework/op_kernel.h" | ||
#include "tensorflow/core/framework/register_types.h" | ||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" | ||
|
||
namespace tensorflow { | ||
namespace addons { | ||
|
||
using CPUDevice = Eigen::ThreadPoolDevice; | ||
|
||
#define REGISTER_LISHT_KERNELS(type) \ | ||
REGISTER_KERNEL_BUILDER( \ | ||
Name("Addons>Lisht").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ | ||
LishtOp<CPUDevice, type>); \ | ||
REGISTER_KERNEL_BUILDER( \ | ||
Name("Addons>LishtGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ | ||
LishtGradOp<CPUDevice, type>); | ||
|
||
// Lisht only makes sense with floating points. | ||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_LISHT_KERNELS); | ||
#undef REGISTER_LISHT_KERNELS | ||
|
||
#if GOOGLE_CUDA | ||
|
||
using GPUDevice = Eigen::GpuDevice; | ||
|
||
// Forward declarations of the functor specializations for GPU. | ||
namespace functor { | ||
#define DECLARE_GPU_SPEC(T) \ | ||
template <> \ | ||
void Lisht<GPUDevice, T>::operator()( \ | ||
const GPUDevice& d, typename TTypes<T>::ConstTensor features, \ | ||
typename TTypes<T>::Tensor activations); \ | ||
extern template struct Lisht<GPUDevice, T>; \ | ||
\ | ||
template <> \ | ||
void LishtGrad<GPUDevice, T>::operator()( \ | ||
const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \ | ||
typename TTypes<T>::ConstTensor features, \ | ||
typename TTypes<T>::Tensor backprops); \ | ||
extern template struct LishtGrad<GPUDevice, T>; | ||
|
||
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); | ||
#undef DECLARE_GPU_SPEC | ||
} // namespace functor | ||
|
||
// Registration of the GPU implementations. | ||
#define REGISTER_LISHT_GPU_KERNELS(type) \ | ||
REGISTER_KERNEL_BUILDER( \ | ||
Name("Addons>Lisht").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ | ||
LishtOp<GPUDevice, type>); \ | ||
REGISTER_KERNEL_BUILDER( \ | ||
Name("Addons>LishtGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ | ||
LishtGradOp<GPUDevice, type>); | ||
|
||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_LISHT_GPU_KERNELS); | ||
#undef REGISTER_LISHT_GPU_KERNELS | ||
|
||
#endif // GOOGLE_CUDA | ||
|
||
} // namespace addons | ||
} // namespace tensorflow |
Oops, something went wrong.