add rrelu kernel #573

Merged: 14 commits, Oct 25, 2019
format code
fsx950223 committed Oct 20, 2019
commit 3db650297997c17ac2b44a0c2bf7bfe7d0cdfe04
2 changes: 1 addition & 1 deletion tensorflow_addons/activations/activations_test.py
@@ -26,7 +26,7 @@
class ActivationsTest(tf.test.TestCase):

    ALL_ACTIVATIONS = [
-        "gelu", "hardshrink", "lisht", "softshrink", "sparsemax", "tanhshrink"
+        "gelu", "hardshrink", "lisht", "softshrink", "sparsemax","rrelu", "tanhshrink"
    ]

    def test_serialization(self):
5 changes: 4 additions & 1 deletion tensorflow_addons/activations/rrelu.py
@@ -53,4 +53,7 @@ def rrelu(x, lower=0.125, upper=0.3333333333333333, training=None):
@tf.RegisterGradient("Addons>Rrelu")
def _rrelu_grad(op, grad):
    return _activation_ops_so.addons_rrelu_grad(grad, op.inputs[0],
-        op.outputs[1],op.get_attr("lower"),op.get_attr("upper"),op.get_attr("training"))
+                                                op.outputs[1],
+                                                op.get_attr("lower"),
+                                                op.get_attr("upper"),
+                                                op.get_attr("training"))
6 changes: 3 additions & 3 deletions tensorflow_addons/activations/rrelu_test.py
@@ -25,13 +25,13 @@
import tensorflow as tf
from tensorflow_addons.activations import rrelu
from tensorflow_addons.utils import test_utils
-import random
+import random

-SEED=111111
+SEED = 111111
tf.random.set_seed(SEED)
random.seed(SEED)


@test_utils.run_all_in_graph_and_eager_modes
class RreluTest(tf.test.TestCase, parameterized.TestCase):
    @parameterized.named_parameters(("float16", np.float16),
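As a quick orientation for the test above, here is a minimal usage sketch of the Python wrapper (assuming a TensorFlow Addons build that includes this custom op; the calls are illustrative and not taken from the test itself):

import tensorflow as tf
from tensorflow_addons.activations import rrelu

x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0])

# Inference path: negative inputs get a fixed, deterministic scale.
print(rrelu(x, training=False))

# Training path: each negative element gets its own randomly drawn alpha,
# which is why rrelu_test.py seeds tf.random and Python's random module up front.
tf.random.set_seed(111111)
print(rrelu(x, training=True))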
tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op_gpu.cu.cc
@@ -61,11 +61,11 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC
}  // namespace functor

-#define REGISTER_RRELU_GPU_KERNELS(T) \
-  REGISTER_KERNEL_BUILDER( \
+#define REGISTER_RRELU_GPU_KERNELS(T) \
+  REGISTER_KERNEL_BUILDER( \
      Name("Addons>Rrelu").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
-      RreluOp<GPUDevice, T>); \
-  REGISTER_KERNEL_BUILDER( \
+      RreluOp<GPUDevice, T>); \
+  REGISTER_KERNEL_BUILDER( \
      Name("Addons>RreluGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      RreluGradOp<GPUDevice, T>);

37 changes: 17 additions & 20 deletions tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op.h
@@ -33,13 +33,13 @@ struct Rrelu {
  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
                  T lower, T upper, bool training,
                  typename TTypes<T>::Tensor activations,
-                 typename TTypes<T>::Tensor alpha){
+                 typename TTypes<T>::Tensor alpha) {
    if (training) {
      alpha.device(d) = alpha.constant(lower) +
-         (alpha.random() + alpha.constant(static_cast<T>(1))) *
-         alpha.constant((upper - lower) / static_cast<T>(2));
-     activations.device(d) = (features >= static_cast<T>(0))
-         .select(features, alpha * features);
+                        (alpha.random() + alpha.constant(static_cast<T>(1))) *
+                            alpha.constant((upper - lower) / static_cast<T>(2));
+      activations.device(d) =
+          (features >= static_cast<T>(0)).select(features, alpha * features);
    } else {
      activations.device(d) =
          (features >= static_cast<T>(0))
@@ -54,7 +54,7 @@ struct RreluGrad {
  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
                  typename TTypes<T>::ConstTensor features,
                  typename TTypes<T>::ConstTensor alpha, T lower, T upper,
-                 bool training, typename TTypes<T>::Tensor backprops){
+                 bool training, typename TTypes<T>::Tensor backprops) {
    if (training) {
      backprops.device(d) =
          gradients *
@@ -82,12 +82,10 @@ class RreluOp : public OpKernel {
    OP_REQUIRES_OK(context, context->GetAttr("training", &training_));
    lower_ = static_cast<T>(lower);
    OP_REQUIRES(context, lower_ >= static_cast<T>(0),
-                errors::InvalidArgument("Need lower >= 0, got ",
-                                        lower_));
+                errors::InvalidArgument("Need lower >= 0, got ", lower_));
    upper_ = static_cast<T>(upper);
    OP_REQUIRES(context, upper_ < static_cast<T>(1),
-                errors::InvalidArgument("Need upper < 1, got ",
-                                        upper_));
+                errors::InvalidArgument("Need upper < 1, got ", upper_));
    OP_REQUIRES(
        context, lower_ <= upper_,
        errors::InvalidArgument("lower must be less than or equal to upper."));
@@ -101,9 +99,9 @@ class RreluOp : public OpKernel {
    OP_REQUIRES_OK(context, context->allocate_output(1, input_tensor.shape(),
                                                     &alpha_tensor));
    // functor::Rrelu<Device, T> functor;
-    functor::Rrelu<Device, T>()(context->eigen_device<Device>(), input_tensor.flat<T>(), lower_,
-                                upper_, training_, output_tensor->flat<T>(),
-                                alpha_tensor->flat<T>());
+    functor::Rrelu<Device, T>()(
+        context->eigen_device<Device>(), input_tensor.flat<T>(), lower_, upper_,
+        training_, output_tensor->flat<T>(), alpha_tensor->flat<T>());
}

private:
@@ -122,12 +120,10 @@ class RreluGradOp : public OpKernel {
    OP_REQUIRES_OK(context, context->GetAttr("training", &training_));
    lower_ = static_cast<T>(lower);
    OP_REQUIRES(context, lower_ >= static_cast<T>(0),
-                errors::InvalidArgument("Need lower >= 0, got ",
-                                        lower_));
+                errors::InvalidArgument("Need lower >= 0, got ", lower_));
    upper_ = static_cast<T>(upper);
    OP_REQUIRES(context, upper_ < static_cast<T>(1),
-                errors::InvalidArgument("Need upper < 1, got ",
-                                        upper_));
+                errors::InvalidArgument("Need upper < 1, got ", upper_));
    OP_REQUIRES(
        context, lower_ <= upper_,
        errors::InvalidArgument("lower must be less than or equal to upper."));
@@ -140,9 +136,10 @@ class RreluGradOp : public OpKernel {
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    // functor::RreluGrad<Device, T> functor;
-    functor::RreluGrad<Device, T>()(context->eigen_device<Device>(), gradients.flat<T>(),
-        input_tensor.flat<T>(), alpha_tensor.flat<T>(), lower_, upper_,
-        training_, output_tensor->flat<T>());
+    functor::RreluGrad<Device, T>()(context->eigen_device<Device>(),
+                                    gradients.flat<T>(), input_tensor.flat<T>(),
+                                    alpha_tensor.flat<T>(), lower_, upper_,
+                                    training_, output_tensor->flat<T>());
}

private:
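One detail from the training branch of the Rrelu functor above is worth spelling out: lower + (alpha.random() + 1) * (upper - lower) / 2 is an affine map that sends a value r in [-1, 1] onto [lower, upper]. The small, self-contained check below illustrates just that mapping (plain Python; it does not exercise Eigen, whose generator range is assumed here rather than verified):

lower, upper = 0.125, 1 / 3

def map_alpha(r):
    # Affine map used in the training branch: r in [-1, 1] -> alpha in [lower, upper].
    return lower + (r + 1) * (upper - lower) / 2

assert map_alpha(-1.0) == lower              # left endpoint maps to lower
assert abs(map_alpha(1.0) - upper) < 1e-12   # right endpoint maps to upper
assert lower < map_alpha(0.0) < upper        # interior values stay inside the range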