Skip to content

Commit

Permalink
add rrelu kernel (#573)
Browse files Browse the repository at this point in the history
* add rrelu kernel

* format code

* change test case

* simplify kernel implemention

* fix test case

* refact test case

* remove forward training test

* format code

* modify test case

* add forget files

* update kernel random implementation

* fix rrelu test bug

* add TODO

* fix comment
  • Loading branch information
fsx950223 authored and WindQAQ committed Oct 25, 2019
1 parent 7801294 commit 8a49f91
Show file tree
Hide file tree
Showing 11 changed files with 462 additions and 1 deletion.
14 changes: 14 additions & 0 deletions tensorflow_addons/activations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ py_library(
"gelu.py",
"hardshrink.py",
"lisht.py",
"rrelu.py",
"softshrink.py",
"sparsemax.py",
"tanhshrink.py",
Expand Down Expand Up @@ -110,3 +111,16 @@ py_test(
":activations",
],
)

py_test(
name = "rrelu_test",
size = "small",
srcs = [
"rrelu_test.py",
],
main = "rrelu_test.py",
srcs_version = "PY2AND3",
deps = [
":activations",
],
)
2 changes: 2 additions & 0 deletions tensorflow_addons/activations/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
| softshrink| @WindQAQ | [email protected] |
| sparsemax | @AndreasMadsen | [email protected] |
| tanhshrink| @fsx950223 | [email protected] |
| rrelu | @fsx950223 | [email protected] |

## Contents
| Submodule | Activation | Reference |
Expand All @@ -19,6 +20,7 @@
| softshrink| softshrink | |
| sparsemax | sparsemax | https://arxiv.org/abs/1602.02068 |
| tanhshrink| tanhshrink | |
| rrelu | rrelu | https://arxiv.org/abs/1505.00853 |


## Contribution Guidelines
Expand Down
1 change: 1 addition & 0 deletions tensorflow_addons/activations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@
from tensorflow_addons.activations.hardshrink import hardshrink
from tensorflow_addons.activations.lisht import lisht
from tensorflow_addons.activations.softshrink import softshrink
from tensorflow_addons.activations.rrelu import rrelu
from tensorflow_addons.activations.sparsemax import sparsemax
from tensorflow_addons.activations.tanhshrink import tanhshrink
3 changes: 2 additions & 1 deletion tensorflow_addons/activations/activations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
class ActivationsTest(tf.test.TestCase):

ALL_ACTIVATIONS = [
"gelu", "hardshrink", "lisht", "softshrink", "sparsemax", "tanhshrink"
"gelu", "hardshrink", "lisht", "softshrink", "sparsemax", "rrelu",
"tanhshrink"
]

def test_serialization(self):
Expand Down
65 changes: 65 additions & 0 deletions tensorflow_addons/activations/rrelu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils import keras_utils
from tensorflow_addons.utils.resource_loader import get_path_to_datafile

_activation_ops_so = tf.load_op_library(
get_path_to_datafile("custom_ops/activations/_activation_ops.so"))


@keras_utils.register_keras_custom_object
@tf.function
def rrelu(x, lower=0.125, upper=0.3333333333333333, training=None, seed=None):
"""rrelu function.
Computes rrelu function:
`x if x > 0 else random(lower, upper) * x` or
`x if x > 0 else x * (lower + upper) / 2`
depending on whether training is enabled.
See [Empirical Evaluation of Rectified Activations in Convolutional Network](https://arxiv.org/abs/1505.00853).
Args:
x: A `Tensor`. Must be one of the following types:
`float16`, `float32`, `float64`.
lower: `float`, lower bound for random alpha.
upper: `float`, upper bound for random alpha.
training: `bool`, indicating whether the `call`
is meant for training or inference.
seed: `int`, this sets the operation-level seed.
Returns:
result: A `Tensor`. Has the same type as `x`.
"""
x = tf.convert_to_tensor(x)
if training is None:
training = tf.keras.backend.learning_phase()
training = bool(tf.keras.backend.get_value(training))
# TODO: get rid of v1 API
seed1, seed2 = tf.compat.v1.random.get_seed(seed)
result, _ = _activation_ops_so.addons_rrelu(x, lower, upper, training,
seed1, seed2)
return result


@tf.RegisterGradient("Addons>Rrelu")
def _rrelu_grad(op, *grad):
return _activation_ops_so.addons_rrelu_grad(grad[0], op.inputs[0],
op.outputs[1])
78 changes: 78 additions & 0 deletions tensorflow_addons/activations/rrelu_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

import numpy as np
import tensorflow as tf
from tensorflow_addons.activations import rrelu
from tensorflow_addons.utils import test_utils


def _ref_rrelu(x, lower, upper):
return tf.where(x >= 0, x, (lower + upper) * x / 2)


@test_utils.run_all_in_graph_and_eager_modes
class RreluTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(("float16", np.float16),
("float32", np.float32),
("float64", np.float64))
@tf.function
def test_rrelu(self, dtype):
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
lower = 0.1
upper = 0.2
result = rrelu(x, lower, upper, training=False)
expect_result = _ref_rrelu(x, lower, upper)
self.assertAllCloseAccordingToType(result, expect_result)

@parameterized.named_parameters(("float32", np.float32),
("float64", np.float64))
def test_theoretical_gradients(self, dtype):
x = tf.constant([-2.0, -1.0, -0.1, 0.1, 1.0, 2.0], dtype=dtype)
lower = 0.1
upper = 0.2
for training in [True, False]:
with self.subTest(training=training):
theoretical, numerical = tf.test.compute_gradient(
lambda x: rrelu(
x, lower, upper, training=training, seed=111111), [x])
# TODO: investigate the difference between CPU and GPU
if training is True and tf.test.is_gpu_available() is False:
numerical = [[[0.134971, 0., 0., 0., 0., 0.],
[0., 0.15648358, 0., 0., 0., 0.],
[0., 0., 0.18776372, 0., 0., 0.],
[0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 1.]]]
self.assertAllCloseAccordingToType(
theoretical, numerical, rtol=5e-4, atol=5e-4)

def test_unknown_shape(self):
fn = rrelu.get_concrete_function(
tf.TensorSpec(shape=None, dtype=tf.float32))

for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
x = tf.ones(shape=shape, dtype=tf.float32)
self.assertAllClose(fn(x), rrelu(x))


if __name__ == "__main__":
tf.test.main()
5 changes: 5 additions & 0 deletions tensorflow_addons/custom_ops/activations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@ custom_op_library(
"cc/kernels/hardshrink_op.h",
"cc/kernels/lisht_op.cc",
"cc/kernels/lisht_op.h",
"cc/kernels/rrelu_op.cc",
"cc/kernels/rrelu_op.h",
"cc/kernels/softshrink_op.cc",
"cc/kernels/softshrink_op.h",
"cc/kernels/tanhshrink_op.cc",
"cc/kernels/tanhshrink_op.h",
"cc/ops/gelu_op.cc",
"cc/ops/hardshrink_op.cc",
"cc/ops/lisht_op.cc",
"cc/ops/rrelu_op.cc",
"cc/ops/softshrink_op.cc",
"cc/ops/tanhshrink_op.cc",
],
Expand All @@ -30,6 +33,8 @@ custom_op_library(
"cc/kernels/hardshrink_op_gpu.cu.cc",
"cc/kernels/lisht_op.h",
"cc/kernels/lisht_op_gpu.cu.cc",
"cc/kernels/rrelu_op.h",
"cc/kernels/rrelu_op_gpu.cu.cc",
"cc/kernels/softshrink_op.h",
"cc/kernels/softshrink_op_gpu.cu.cc",
"cc/kernels/tanhshrink_op.h",
Expand Down
78 changes: 78 additions & 0 deletions tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"

namespace tensorflow {
namespace addons {

using CPUDevice = Eigen::ThreadPoolDevice;

#define REGISTER_RRELU_KERNELS(T) \
REGISTER_KERNEL_BUILDER( \
Name("Addons>Rrelu").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RreluOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("Addons>RreluGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RreluGradOp<CPUDevice, T>);

// Rrelu only makes sense with floating points.
TF_CALL_GPU_NUMBER_TYPES(REGISTER_RRELU_KERNELS);
#undef REGISTER_RRELU_KERNELS

#if GOOGLE_CUDA

using GPUDevice = Eigen::GpuDevice;

namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void Rrelu<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor features, T lower, \
T upper, bool training, typename TTypes<T>::Tensor activations, \
typename TTypes<T>::Tensor alpha, \
typename random::SimplePhilox& random); \
extern template struct Rrelu<GPUDevice, T>; \
\
template <> \
void RreluGrad<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::ConstTensor alpha, \
typename TTypes<T>::Tensor backprops); \
extern template struct RreluGrad<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC
} // namespace functor

#define REGISTER_RRELU_GPU_KERNELS(T) \
REGISTER_KERNEL_BUILDER( \
Name("Addons>Rrelu").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
RreluOp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("Addons>RreluGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
RreluGradOp<GPUDevice, T>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_RRELU_GPU_KERNELS);
#undef REGISTER_RRELU_GPU_KERNELS

#endif // GOOGLE_CUDA
} // namespace addons
} // namespace tensorflow
Loading

0 comments on commit 8a49f91

Please sign in to comment.