add lisht kernel #529

Merged · 12 commits · Sep 30, 2019
14 changes: 14 additions & 0 deletions tensorflow_addons/activations/BUILD
@@ -8,6 +8,7 @@ py_library(
"__init__.py",
"gelu.py",
"hardshrink.py",
"lisht.py",
"sparsemax.py",
"tanhshrink.py",
],
@@ -57,6 +58,19 @@ py_test(
],
)

py_test(
name = "lisht_test",
size = "medium",
srcs = [
"lisht_test.py",
],
main = "lisht_test.py",
srcs_version = "PY2AND3",
deps = [
":activations",
],
)

py_test(
name = "tanhshrink_test",
size = "medium",
11 changes: 6 additions & 5 deletions tensorflow_addons/activations/README.md
@@ -4,25 +4,26 @@
| Submodule | Maintainers | Contact Info |
|:----------|:--------------------------|:-----------------------------------------|
| gelu | @AakashKumarNain @WindQAQ | [email protected] [email protected] |
| hardshrink| @WindQAQ | [email protected]
| hardshrink| @WindQAQ | [email protected] |
| lisht | @WindQAQ | [email protected] |
| sparsemax | @AndreasMadsen | [email protected] |
| tanhshrink | @fsx950223 | [email protected] |
| tanhshrink| @fsx950223 | [email protected] |

## Contents
| Submodule | Activation | Reference |
|:----------|:-----------|:---------------------------------|
| gelu | gelu | https://arxiv.org/abs/1606.08415 |
| hardshrink| hardshrink | |
| sparsemax | Sparsemax | https://arxiv.org/abs/1602.02068 |
| tanhshrink | Tanhshrink | |
| lisht | lisht | https://arxiv.org/abs/1901.05894 |
| sparsemax | sparsemax | https://arxiv.org/abs/1602.02068 |
| tanhshrink| tanhshrink | |


## Contribution Guidelines
#### Standard API
In order to conform with the current API standard, all activations
must:
* Be a `tf.function`.
* Have the signature `fn(input, axis=-1, name=None)`.
Member: Not needed in this PR, but wondering if we should remove the `name` parameter from the sparsemax activation.

Member (Author): I think it would be better to remove it (at least then the style is consistent within the same submodule), but I'm not sure whether other operations in image, text, etc. should keep it.

Member: I prefer to keep `name` in other modules :-)

Member (Author): Sure, I will create another PR to clean up the activation functions (`name` arg, duplicated tests, etc.).

* [Register as a keras global object](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/utils/python/keras_utils.py)
so it can be serialized properly.
* Add the addon to the `py_library` in this sub-package's BUILD file.
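
For illustration, a minimal sketch of an activation that follows the standard above (the function name and body are placeholders, not part of this PR):

```python
import tensorflow as tf
from tensorflow_addons.utils import keras_utils


@keras_utils.register_keras_custom_object
@tf.function
def example_activation(x):
    """Placeholder element-wise activation written against the standard API."""
    x = tf.convert_to_tensor(x)
    return x * tf.math.tanh(x)  # any element-wise computation
```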
1 change: 1 addition & 0 deletions tensorflow_addons/activations/__init__.py
@@ -20,5 +20,6 @@

from tensorflow_addons.activations.gelu import gelu
from tensorflow_addons.activations.hardshrink import hardshrink
from tensorflow_addons.activations.lisht import lisht
from tensorflow_addons.activations.sparsemax import sparsemax
from tensorflow_addons.activations.tanhshrink import tanhshrink
49 changes: 49 additions & 0 deletions tensorflow_addons/activations/lisht.py
@@ -0,0 +1,49 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils import keras_utils
from tensorflow_addons.utils.resource_loader import get_path_to_datafile

_activation_ops_so = tf.load_op_library(
get_path_to_datafile("custom_ops/activations/_activation_ops.so"))


@keras_utils.register_keras_custom_object
@tf.function
def lisht(x):
"""LiSHT: Non-Parameteric Linearly Scaled Hyperbolic Tangent Activation Function.

Computes linearly scaled hyperbolic tangent (LiSHT): `x * tanh(x)`

See [LiSHT: Non-Parametric Linearly Scaled Hyperbolic Tangent Activation Function for Neural Networks](https://arxiv.org/abs/1901.05894).

Args:
x: A `Tensor`. Must be one of the following types:
`float16`, `float32`, `float64`.
Returns:
A `Tensor`. Has the same type as `x`.
"""
x = tf.convert_to_tensor(x)
return _activation_ops_so.addons_lisht(x)


@tf.RegisterGradient("Addons>Lisht")
def _lisht_grad(op, grad):
return _activation_ops_so.addons_lisht_grad(grad, op.inputs[0])
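
For reference, the forward and gradient computations implemented by the custom kernel can be sketched with plain TensorFlow ops (an equivalent expression for checking, not the registered op itself):

```python
import tensorflow as tf


def lisht_reference(x):
    """Pure-TF reference for lisht: x * tanh(x)."""
    x = tf.convert_to_tensor(x)
    return x * tf.math.tanh(x)


def lisht_grad_reference(x):
    """Analytic gradient of x * tanh(x): tanh(x) + x * (1 - tanh(x)^2)."""
    t = tf.math.tanh(x)
    return t + x * (1.0 - tf.math.square(t))
```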
73 changes: 73 additions & 0 deletions tensorflow_addons/activations/lisht_test.py
@@ -0,0 +1,73 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

import numpy as np
import tensorflow as tf
from tensorflow_addons.activations import lisht
from tensorflow_addons.utils import test_utils


@test_utils.run_all_in_graph_and_eager_modes
class LishtTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(("float16", np.float16),
("float32", np.float32),
("float64", np.float64))
def test_lisht(self, dtype):
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
expected_result = tf.constant(
[1.9280552, 0.7615942, 0.0, 0.7615942, 1.9280552], dtype=dtype)
self.assertAllCloseAccordingToType(lisht(x), expected_result)

@parameterized.named_parameters(("float32", np.float32),
("float64", np.float64))
def test_theoretical_gradients(self, dtype):
# Only test theoretical gradients for float32 and float64
# because of the instability of float16 while computing jacobian
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)

theoretical, numerical = tf.test.compute_gradient(lisht, [x])
self.assertAllCloseAccordingToType(
theoretical, numerical, rtol=5e-4, atol=5e-4)

def test_unknown_shape(self):
fn = lisht.get_concrete_function(
tf.TensorSpec(shape=None, dtype=tf.float32))

for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
x = tf.ones(shape=shape, dtype=tf.float32)
self.assertAllClose(fn(x), lisht(x))

def test_serialization(self):
config = tf.keras.activations.serialize(lisht)
fn = tf.keras.activations.deserialize(config)
self.assertEqual(fn, lisht)

def test_serialization_with_layers(self):
layer = tf.keras.layers.Dense(3, activation=lisht)
config = tf.keras.layers.serialize(layer)
deserialized_layer = tf.keras.layers.deserialize(config)
self.assertEqual(deserialized_layer.__class__.__name__,
layer.__class__.__name__)
self.assertEqual(deserialized_layer.activation.__name__, "lisht")


if __name__ == "__main__":
tf.test.main()
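
For context, the serialization tests above mirror the intended end use; a minimal usage sketch (layer sizes and data are made up for illustration):

```python
import numpy as np
import tensorflow as tf
from tensorflow_addons.activations import lisht

# lisht can be passed as a callable activation to any Keras layer.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation=lisht, input_shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.rand(16, 4).astype("float32"),
          np.random.rand(16, 1).astype("float32"),
          epochs=1, verbose=0)
```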
26 changes: 26 additions & 0 deletions tensorflow_addons/custom_ops/activations/BUILD
@@ -49,6 +49,28 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "lisht_op_gpu",
srcs = [
"cc/kernels/lisht_op.h",
"cc/kernels/lisht_op_gpu.cu.cc",
],
copts = if_cuda_is_configured([
"-DGOOGLE_CUDA=1",
"-x cuda",
"-nvcc_options=relaxed-constexpr",
"-nvcc_options=ftz=true",
]),
deps = [
"@local_config_tf//:libtensorflow_framework",
"@local_config_tf//:tf_header_lib",
] + if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_libs",
"@local_config_cuda//cuda:cuda_headers",
Member: Maybe there should be a Bazel function to wrap the CUDA config.

]),
alwayslink = 1,
)

cc_library(
name = "tanhshrink_op_gpu",
srcs = [
@@ -78,10 +100,13 @@ cc_binary(
"cc/kernels/gelu_op.h",
"cc/kernels/hardshrink_op.cc",
"cc/kernels/hardshrink_op.h",
"cc/kernels/lisht_op.cc",
"cc/kernels/lisht_op.h",
"cc/kernels/tanhshrink_op.cc",
"cc/kernels/tanhshrink_op.h",
"cc/ops/gelu_op.cc",
"cc/ops/hardshrink_op.cc",
"cc/ops/lisht_op.cc",
"cc/ops/tanhshrink_op.cc",
],
copts = [
@@ -96,6 +121,7 @@
] + if_cuda_is_configured([
":gelu_op_gpu",
":hardshrink_op_gpu",
":lisht_op_gpu",
":tanhshrink_op_gpu",
]),
)
79 changes: 79 additions & 0 deletions tensorflow_addons/custom_ops/activations/cc/kernels/lisht_op.cc
@@ -0,0 +1,79 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow_addons/custom_ops/activations/cc/kernels/lisht_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {
namespace addons {

using CPUDevice = Eigen::ThreadPoolDevice;
Member: typedef?

Member (Author): Just out of curiosity, is there any difference between the two? I assume our ops are compiled with the C++11 standard. Quoting cppreference: "There is no difference between a type alias declaration and typedef declaration."

Member: Good question. The difference is that the alias declaration is compatible with templates, whereas a C-style typedef is not. My first thought was consistency with TF core.


#define REGISTER_LISHT_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Addons>Lisht").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
LishtOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("Addons>LishtGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
LishtGradOp<CPUDevice, type>);

// Lisht only makes sense with floating-point types.
TF_CALL_GPU_NUMBER_TYPES(REGISTER_LISHT_KERNELS);
#undef REGISTER_LISHT_KERNELS

#if GOOGLE_CUDA

using GPUDevice = Eigen::GpuDevice;
Member: ditto


// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void Lisht<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor activations); \
extern template struct Lisht<GPUDevice, T>; \
\
template <> \
void LishtGrad<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor backprops); \
extern template struct LishtGrad<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC
} // namespace functor

// Registration of the GPU implementations.
#define REGISTER_LISHT_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Addons>Lisht").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
LishtOp<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("Addons>LishtGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
LishtGradOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_LISHT_GPU_KERNELS);
#undef REGISTER_LISHT_GPU_KERNELS

#endif // GOOGLE_CUDA

} // namespace addons
} // namespace tensorflow