diff --git a/tensorflow_addons/activations/BUILD b/tensorflow_addons/activations/BUILD
index af015cef5f..a67b39be21 100644
--- a/tensorflow_addons/activations/BUILD
+++ b/tensorflow_addons/activations/BUILD
@@ -9,6 +9,7 @@ py_library(
         "gelu.py",
         "hardshrink.py",
         "lisht.py",
+        "rrelu.py",
         "softshrink.py",
         "sparsemax.py",
         "tanhshrink.py",
@@ -110,3 +111,16 @@ py_test(
         ":activations",
     ],
 )
+
+py_test(
+    name = "rrelu_test",
+    size = "small",
+    srcs = [
+        "rrelu_test.py",
+    ],
+    main = "rrelu_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":activations",
+    ],
+)
diff --git a/tensorflow_addons/activations/README.md b/tensorflow_addons/activations/README.md
index b8f233fa3f..3a0cfb1323 100644
--- a/tensorflow_addons/activations/README.md
+++ b/tensorflow_addons/activations/README.md
@@ -9,6 +9,7 @@
 | softshrink| @WindQAQ | windqaq@gmail.com |
 | sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com |
 | tanhshrink| @fsx950223 | fsx950223@gmail.com |
+| rrelu | @fsx950223 | fsx950223@gmail.com |
 
 ## Contents
 | Submodule | Activation | Reference |
@@ -19,6 +20,7 @@
 | softshrink| softshrink | |
 | sparsemax | sparsemax | https://arxiv.org/abs/1602.02068 |
 | tanhshrink| tanhshrink | |
+| rrelu | rrelu | https://arxiv.org/abs/1505.00853 |
 
 ## Contribution Guidelines
diff --git a/tensorflow_addons/activations/__init__.py b/tensorflow_addons/activations/__init__.py
index ba9d6a3738..af8dc4ead8 100644
--- a/tensorflow_addons/activations/__init__.py
+++ b/tensorflow_addons/activations/__init__.py
@@ -22,5 +22,6 @@
 from tensorflow_addons.activations.hardshrink import hardshrink
 from tensorflow_addons.activations.lisht import lisht
 from tensorflow_addons.activations.softshrink import softshrink
+from tensorflow_addons.activations.rrelu import rrelu
 from tensorflow_addons.activations.sparsemax import sparsemax
 from tensorflow_addons.activations.tanhshrink import tanhshrink
diff --git a/tensorflow_addons/activations/activations_test.py b/tensorflow_addons/activations/activations_test.py
index d685e7d2ca..58b946577d 100644
--- a/tensorflow_addons/activations/activations_test.py
+++ b/tensorflow_addons/activations/activations_test.py
@@ -26,7 +26,8 @@ class ActivationsTest(tf.test.TestCase):
 
     ALL_ACTIVATIONS = [
-        "gelu", "hardshrink", "lisht", "softshrink", "sparsemax", "tanhshrink"
+        "gelu", "hardshrink", "lisht", "softshrink", "sparsemax", "rrelu",
+        "tanhshrink"
     ]
 
     def test_serialization(self):
diff --git a/tensorflow_addons/activations/rrelu.py b/tensorflow_addons/activations/rrelu.py
new file mode 100644
index 0000000000..07a0bbfa74
--- /dev/null
+++ b/tensorflow_addons/activations/rrelu.py
@@ -0,0 +1,65 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow_addons.utils import keras_utils
+from tensorflow_addons.utils.resource_loader import get_path_to_datafile
+
+_activation_ops_so = tf.load_op_library(
+    get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
+
+
+@keras_utils.register_keras_custom_object
+@tf.function
+def rrelu(x, lower=0.125, upper=0.3333333333333333, training=None, seed=None):
+    """rrelu function.
+
+    Computes the rrelu function:
+    `x if x > 0 else random(lower, upper) * x` or
+    `x if x > 0 else x * (lower + upper) / 2`
+    depending on whether training is enabled.
+
+    See [Empirical Evaluation of Rectified Activations in Convolutional Network](https://arxiv.org/abs/1505.00853).
+
+    Args:
+        x: A `Tensor`. Must be one of the following types:
+            `float16`, `float32`, `float64`.
+        lower: `float`, lower bound for the random alpha.
+        upper: `float`, upper bound for the random alpha.
+        training: `bool`, indicating whether the `call`
+            is meant for training or inference.
+        seed: `int`, sets the operation-level random seed.
+    Returns:
+        result: A `Tensor`. Has the same type as `x`.
+    """
+    x = tf.convert_to_tensor(x)
+    if training is None:
+        training = tf.keras.backend.learning_phase()
+    training = bool(tf.keras.backend.get_value(training))
+    # TODO: get rid of v1 API
+    seed1, seed2 = tf.compat.v1.random.get_seed(seed)
+    result, _ = _activation_ops_so.addons_rrelu(x, lower, upper, training,
+                                                seed1, seed2)
+    return result
+
+
+@tf.RegisterGradient("Addons>Rrelu")
+def _rrelu_grad(op, *grad):
+    return _activation_ops_so.addons_rrelu_grad(grad[0], op.inputs[0],
+                                                op.outputs[1])
diff --git a/tensorflow_addons/activations/rrelu_test.py b/tensorflow_addons/activations/rrelu_test.py
new file mode 100644
index 0000000000..5c5b283239
--- /dev/null
+++ b/tensorflow_addons/activations/rrelu_test.py
@@ -0,0 +1,78 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+import numpy as np
+import tensorflow as tf
+from tensorflow_addons.activations import rrelu
+from tensorflow_addons.utils import test_utils
+
+
+def _ref_rrelu(x, lower, upper):
+    return tf.where(x >= 0, x, (lower + upper) * x / 2)
+
+
+@test_utils.run_all_in_graph_and_eager_modes
+class RreluTest(tf.test.TestCase, parameterized.TestCase):
+    @parameterized.named_parameters(("float16", np.float16),
+                                    ("float32", np.float32),
+                                    ("float64", np.float64))
+    @tf.function
+    def test_rrelu(self, dtype):
+        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
+        lower = 0.1
+        upper = 0.2
+        result = rrelu(x, lower, upper, training=False)
+        expect_result = _ref_rrelu(x, lower, upper)
+        self.assertAllCloseAccordingToType(result, expect_result)
+
+    @parameterized.named_parameters(("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_theoretical_gradients(self, dtype):
+        x = tf.constant([-2.0, -1.0, -0.1, 0.1, 1.0, 2.0], dtype=dtype)
+        lower = 0.1
+        upper = 0.2
+        for training in [True, False]:
+            with self.subTest(training=training):
+                theoretical, numerical = tf.test.compute_gradient(
+                    lambda x: rrelu(
+                        x, lower, upper, training=training, seed=111111), [x])
+                # TODO: investigate the difference between CPU and GPU
+                if training and not tf.test.is_gpu_available():
+                    numerical = [[[0.134971, 0., 0., 0., 0., 0.],
+                                  [0., 0.15648358, 0., 0., 0., 0.],
+                                  [0., 0., 0.18776372, 0., 0., 0.],
+                                  [0., 0., 0., 1., 0., 0.],
+                                  [0., 0., 0., 0., 1., 0.],
+                                  [0., 0., 0., 0., 0., 1.]]]
+                self.assertAllCloseAccordingToType(
+                    theoretical, numerical, rtol=5e-4, atol=5e-4)
+
+    def test_unknown_shape(self):
+        fn = rrelu.get_concrete_function(
+            tf.TensorSpec(shape=None, dtype=tf.float32))
+
+        for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
+            x = tf.ones(shape=shape, dtype=tf.float32)
+            self.assertAllClose(fn(x), rrelu(x))
+
+
+if __name__ == "__main__":
+    tf.test.main()
diff --git a/tensorflow_addons/custom_ops/activations/BUILD b/tensorflow_addons/custom_ops/activations/BUILD
index f0b7624b4a..f12442cefe 100644
--- a/tensorflow_addons/custom_ops/activations/BUILD
+++ b/tensorflow_addons/custom_ops/activations/BUILD
@@ -13,6 +13,8 @@ custom_op_library(
         "cc/kernels/hardshrink_op.h",
         "cc/kernels/lisht_op.cc",
         "cc/kernels/lisht_op.h",
+        "cc/kernels/rrelu_op.cc",
+        "cc/kernels/rrelu_op.h",
         "cc/kernels/softshrink_op.cc",
         "cc/kernels/softshrink_op.h",
         "cc/kernels/tanhshrink_op.cc",
@@ -20,6 +22,7 @@ custom_op_library(
         "cc/ops/gelu_op.cc",
         "cc/ops/hardshrink_op.cc",
         "cc/ops/lisht_op.cc",
+        "cc/ops/rrelu_op.cc",
         "cc/ops/softshrink_op.cc",
         "cc/ops/tanhshrink_op.cc",
     ],
@@ -30,6 +33,8 @@
         "cc/kernels/hardshrink_op_gpu.cu.cc",
         "cc/kernels/lisht_op.h",
         "cc/kernels/lisht_op_gpu.cu.cc",
+        "cc/kernels/rrelu_op.h",
+        "cc/kernels/rrelu_op_gpu.cu.cc",
         "cc/kernels/softshrink_op.h",
         "cc/kernels/softshrink_op_gpu.cu.cc",
         "cc/kernels/tanhshrink_op.h",
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op.cc
new file mode 100644
index 0000000000..a3b5822a95
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op.cc
@@ -0,0 +1,78 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+namespace addons {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+#define REGISTER_RRELU_KERNELS(T)                                         \
+  REGISTER_KERNEL_BUILDER(                                                \
+      Name("Addons>Rrelu").Device(DEVICE_CPU).TypeConstraint<T>("T"),     \
+      RreluOp<CPUDevice, T>);                                             \
+  REGISTER_KERNEL_BUILDER(                                                \
+      Name("Addons>RreluGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      RreluGradOp<CPUDevice, T>);
+
+// Rrelu only makes sense with floating points.
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_RRELU_KERNELS);
+#undef REGISTER_RRELU_KERNELS
+
+#if GOOGLE_CUDA
+
+using GPUDevice = Eigen::GpuDevice;
+
+namespace functor {
+#define DECLARE_GPU_SPEC(T)                                                   \
+  template <>                                                                 \
+  void Rrelu<GPUDevice, T>::operator()(                                      \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor features, T lower, \
+      T upper, bool training, typename TTypes<T>::Tensor activations,        \
+      typename TTypes<T>::Tensor alpha,                                      \
+      typename random::SimplePhilox& random);                                \
+  extern template struct Rrelu<GPUDevice, T>;                                \
+                                                                              \
+  template <>                                                                 \
+  void RreluGrad<GPUDevice, T>::operator()(                                  \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor gradients,         \
+      typename TTypes<T>::ConstTensor features,                              \
+      typename TTypes<T>::ConstTensor alpha,                                 \
+      typename TTypes<T>::Tensor backprops);                                 \
+  extern template struct RreluGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+#define REGISTER_RRELU_GPU_KERNELS(T)                                     \
+  REGISTER_KERNEL_BUILDER(                                                \
+      Name("Addons>Rrelu").Device(DEVICE_GPU).TypeConstraint<T>("T"),     \
+      RreluOp<GPUDevice, T>);                                             \
+  REGISTER_KERNEL_BUILDER(                                                \
+      Name("Addons>RreluGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+      RreluGradOp<GPUDevice, T>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_RRELU_GPU_KERNELS);
+#undef REGISTER_RRELU_GPU_KERNELS
+
+#endif  // GOOGLE_CUDA
+}  // namespace addons
+}  // namespace tensorflow
\ No newline at end of file
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op.h
new file mode 100644
index 0000000000..620fece874
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op.h
@@ -0,0 +1,136 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_RRELU_OP_H_
+#define TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_RRELU_OP_H_
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/util/guarded_philox_random.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace addons {
+
+namespace functor {
+
+template <typename Device, typename T>
+struct Rrelu {
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+                  T lower, T upper, bool training,
+                  typename TTypes<T>::Tensor activations,
+                  typename TTypes<T>::Tensor alpha,
+                  typename random::SimplePhilox& random) {
+    if (training) {
+      T storage[alpha.size()];
+      typename TTypes<T>::Tensor alpha_tensor(storage, alpha.size());
+      for (int i = 0; i < alpha_tensor.size(); i++) {
+        alpha_tensor(i) =
+            lower + static_cast<T>(random.RandFloat()) * (upper - lower);
+      }
+      alpha.device(d) = alpha_tensor;
+    } else {
+      alpha.device(d) = features.constant((lower + upper) / static_cast<T>(2));
+    }
+    activations.device(d) =
+        (features >= static_cast<T>(0)).select(features, alpha * features);
+  }
+};
+
+template <typename Device, typename T>
+struct RreluGrad {
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+                  typename TTypes<T>::ConstTensor features,
+                  typename TTypes<T>::ConstTensor alpha,
+                  typename TTypes<T>::Tensor backprops) {
+    backprops.device(d) =
+        gradients *
+        (features >= static_cast<T>(0))
+            .select(features.constant(static_cast<T>(1)), alpha);
+  }
+};
+
+}  // namespace functor
+
+template <typename Device, typename T>
+class RreluOp : public OpKernel {
+ public:
+  explicit RreluOp(OpKernelConstruction* context) : OpKernel(context) {
+    float lower, upper;
+    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower));
+    OP_REQUIRES_OK(context, context->GetAttr("upper", &upper));
+    OP_REQUIRES_OK(context, context->GetAttr("training", &training_));
+    OP_REQUIRES_OK(context, generator_.Init(context));
+    lower_ = static_cast<T>(lower);
+    OP_REQUIRES(context, lower_ >= static_cast<T>(0),
+                errors::InvalidArgument("Need lower >= 0, got ", lower_));
+    upper_ = static_cast<T>(upper);
+    OP_REQUIRES(context, upper_ < static_cast<T>(1),
+                errors::InvalidArgument("Need upper < 1, got ", upper_));
+    OP_REQUIRES(
+        context, lower_ <= upper_,
+        errors::InvalidArgument("lower must be less than or equal to upper."));
+  }
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input_tensor = context->input(0);
+    Tensor* output_tensor = nullptr;
+    Tensor* alpha_tensor = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
+                                                     &output_tensor));
+    OP_REQUIRES_OK(context, context->allocate_output(1, input_tensor.shape(),
+                                                     &alpha_tensor));
+    auto local_gen = generator_.ReserveSamples32(2);
+    random::SimplePhilox random(&local_gen);
+    functor::Rrelu<Device, T>()(
+        context->eigen_device<Device>(), input_tensor.flat<T>(), lower_,
+        upper_, training_, output_tensor->flat<T>(), alpha_tensor->flat<T>(),
+        random);
+  }
+
+ private:
+  T lower_;
+  T upper_;
+  bool training_;
+  GuardedPhiloxRandom generator_;
+};
+
+template <typename Device, typename T>
+class RreluGradOp : public OpKernel {
+ public:
+  explicit RreluGradOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& gradients = context->input(0);
+    const Tensor& input_tensor = context->input(1);
+    const Tensor& alpha_tensor = context->input(2);
+    Tensor* output_tensor = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
+                                                     &output_tensor));
+    functor::RreluGrad<Device, T>()(
+        context->eigen_device<Device>(), gradients.flat<T>(),
+        input_tensor.flat<T>(), alpha_tensor.flat<T>(),
+        output_tensor->flat<T>());
+  }
+};
+
+}  // namespace addons
+}  // namespace tensorflow
+
+#undef EIGEN_USE_THREADS
+
+#endif  // TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_RRELU_OP_H_
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op_gpu.cu.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op_gpu.cu.cc
new file mode 100644
index 0000000000..f16673457b
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op_gpu.cu.cc
@@ -0,0 +1,38 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "third_party/eigen3/Eigen/Core"
+
+namespace tensorflow {
+namespace addons {
+
+using GPUDevice = Eigen::GpuDevice;
+
+#define DEFINE_GPU_KERNELS(T)                      \
+  template struct functor::Rrelu<GPUDevice, T>;    \
+  template struct functor::RreluGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
+
+}  // namespace addons
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
\ No newline at end of file
diff --git a/tensorflow_addons/custom_ops/activations/cc/ops/rrelu_op.cc b/tensorflow_addons/custom_ops/activations/cc/ops/rrelu_op.cc
new file mode 100644
index 0000000000..751617b736
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/ops/rrelu_op.cc
@@ -0,0 +1,43 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace addons {
+
+REGISTER_OP("Addons>Rrelu")
+    .Input("features: T")
+    .Output("activations: T")
+    .Output("alpha: T")
+    .Attr("T: {half, float, double}")
+    .Attr("lower: float")
+    .Attr("upper: float")
+    .Attr("training: bool")
+    .Attr("seed: int = 0")
+    .Attr("seed2: int = 0")
+    .SetShapeFn(shape_inference::UnchangedShape);
+
+REGISTER_OP("Addons>RreluGrad")
+    .Input("gradients: T")
+    .Input("features: T")
+    .Input("alpha: T")
+    .Output("backprops: T")
+    .Attr("T: {half, float, double}")
+    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+
+}  // namespace addons
+}  // namespace tensorflow
\ No newline at end of file
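
For reviewers who want to exercise the new Python API end to end, here is a minimal usage sketch. It is not part of the patch and assumes a tensorflow-addons build that already ships the compiled `_activation_ops.so`; the expected inference-mode values follow directly from the docstring formula `x * (lower + upper) / 2` for negative inputs.

    import tensorflow as tf
    from tensorflow_addons.activations import rrelu

    x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0])

    # Inference: negative inputs are scaled by the fixed factor
    # (lower + upper) / 2, which is roughly 0.2292 for the defaults
    # lower=0.125, upper=1/3.
    y_eval = rrelu(x, training=False)
    # y_eval is approximately [-0.4583, -0.2292, 0.0, 1.0, 2.0]

    # Training: each negative input is scaled by an alpha drawn uniformly
    # from [lower, upper]; the op-level seed keeps the draw reproducible
    # for a fixed graph-level seed.
    y_train = rrelu(x, lower=0.1, upper=0.2, training=True, seed=111111)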