Commit: format code

fsx950223 committed Oct 7, 2019
1 parent 8784273 · commit 2459e94
Showing 5 changed files with 29 additions and 29 deletions.
2 changes: 1 addition & 1 deletion tensorflow_addons/activations/activations_test.py
@@ -26,7 +26,7 @@
 class ActivationsTest(tf.test.TestCase):

     ALL_ACTIVATIONS = [
-        "gelu", "hardshrink", "lisht", "sparsemax", "tanhshrink"
+        "gelu", "hardshrink", "lisht", "rrelu", "sparsemax", "tanhshrink"
     ]

     def test_serialization(self):
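Adding "rrelu" to ALL_ACTIVATIONS opts the new activation into the shared serialization round-trip in test_serialization. For orientation, a minimal usage sketch (the call signature comes from rrelu.py below; the test-time mean-slope rule is standard RReLU behavior, assumed rather than visible in this hunk):

    import tensorflow as tf
    from tensorflow_addons.activations import rrelu

    x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0])

    # Inference: negative inputs are scaled by the fixed mean slope
    # (lower + upper) / 2, so the defaults give roughly 0.229 * x.
    y_eval = rrelu(x, training=False)

    # Training: each negative element gets its own slope, drawn
    # uniformly from [lower, upper]; repeated calls differ.
    y_train = rrelu(x, training=True)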
5 changes: 4 additions & 1 deletion tensorflow_addons/activations/rrelu.py
@@ -53,4 +53,7 @@ def rrelu(x, lower=0.125, upper=0.3333333333333333, training=None):
 @tf.RegisterGradient("Addons>Rrelu")
 def _rrelu_grad(op, grad):
     return _activation_ops_so.addons_rrelu_grad(grad, op.inputs[0],
-        op.outputs[1],op.get_attr("lower"),op.get_attr("upper"),op.get_attr("training"))
+                                                op.outputs[1],
+                                                op.get_attr("lower"),
+                                                op.get_attr("upper"),
+                                                op.get_attr("training"))
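The registered gradient only forwards to the compiled kernel, handing it the saved alpha tensor (op.outputs[1]) along with the op's attrs. As a point of reference, a pure-Python sketch of the same forward/backward rule (an illustration with made-up names, not the code in this commit):

    import tensorflow as tf

    LOWER, UPPER = 0.125, 1.0 / 3.0  # mirrors the op's default attrs

    @tf.custom_gradient
    def rrelu_reference(x):
        # Training-mode forward pass: one slope per element in [LOWER, UPPER].
        alpha = tf.random.uniform(tf.shape(x), LOWER, UPPER, dtype=x.dtype)
        y = tf.where(x >= 0, x, alpha * x)

        def grad(dy):
            # Matches the rule addons_rrelu_grad implements: gradient is 1
            # where the input was non-negative, and the saved alpha elsewhere.
            return dy * tf.where(x >= 0, tf.ones_like(x), alpha)

        return y, grad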
6 changes: 3 additions & 3 deletions tensorflow_addons/activations/rrelu_test.py
@@ -25,13 +25,13 @@
 import tensorflow as tf
 from tensorflow_addons.activations import rrelu
 from tensorflow_addons.utils import test_utils
 import random


-SEED=111111
+SEED = 111111
 tf.random.set_seed(SEED)
 random.seed(SEED)


 @test_utils.run_all_in_graph_and_eager_modes
 class RreluTest(tf.test.TestCase, parameterized.TestCase):
     @parameterized.named_parameters(("float16", np.float16),
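Seeding both TensorFlow's and Python's RNGs is what keeps the training-mode assertions reproducible, since rrelu's sampled alphas depend on global RNG state. A small eager-mode illustration of the TensorFlow side (not from the diff):

    import tensorflow as tf

    tf.random.set_seed(111111)
    a = tf.random.uniform([3])
    tf.random.set_seed(111111)
    b = tf.random.uniform([3])
    # Resetting the global seed replays the same sequence, so the
    # two draws match element-wise.
    assert bool(tf.reduce_all(a == b))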
8 changes: 4 additions & 4 deletions tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op_gpu.cu.cc
@@ -61,11 +61,11 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
 #undef DECLARE_GPU_SPEC
 }  // namespace functor

-#define REGISTER_RRELU_GPU_KERNELS(T)                                 \
-  REGISTER_KERNEL_BUILDER(                                            \
+#define REGISTER_RRELU_GPU_KERNELS(T)                                  \
+  REGISTER_KERNEL_BUILDER(                                             \
       Name("Addons>Rrelu").Device(DEVICE_GPU).TypeConstraint<T>("T"),  \
-      RreluOp<GPUDevice, T>);                                          \
-  REGISTER_KERNEL_BUILDER(                                             \
+      RreluOp<GPUDevice, T>);                                         \
+  REGISTER_KERNEL_BUILDER(                                            \
       Name("Addons>RreluGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
       RreluGradOp<GPUDevice, T>);

37 changes: 17 additions & 20 deletions tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op.h
@@ -33,13 +33,13 @@ struct Rrelu {
   void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
                   T lower, T upper, bool training,
                   typename TTypes<T>::Tensor activations,
-                  typename TTypes<T>::Tensor alpha){
+                  typename TTypes<T>::Tensor alpha) {
     if (training) {
       alpha.device(d) = alpha.constant(lower) +
-          (alpha.random() + alpha.constant(static_cast<T>(1))) *
-          alpha.constant((upper - lower) / static_cast<T>(2));
-      activations.device(d) = (features >= static_cast<T>(0))
-                                  .select(features, alpha * features);
+                        (alpha.random() + alpha.constant(static_cast<T>(1))) *
+                            alpha.constant((upper - lower) / static_cast<T>(2));
+      activations.device(d) =
+          (features >= static_cast<T>(0)).select(features, alpha * features);
     } else {
       activations.device(d) =
           (features >= static_cast<T>(0))
@@ -54,7 +54,7 @@ struct RreluGrad {
   void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
                   typename TTypes<T>::ConstTensor features,
                   typename TTypes<T>::ConstTensor alpha, T lower, T upper,
-                  bool training, typename TTypes<T>::Tensor backprops){
+                  bool training, typename TTypes<T>::Tensor backprops) {
     if (training) {
       backprops.device(d) =
           gradients *
@@ -82,12 +82,10 @@ class RreluOp : public OpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("training", &training_));
     lower_ = static_cast<T>(lower);
     OP_REQUIRES(context, lower_ >= static_cast<T>(0),
-                errors::InvalidArgument("Need lower >= 0, got ",
-                                        lower_));
+                errors::InvalidArgument("Need lower >= 0, got ", lower_));
     upper_ = static_cast<T>(upper);
     OP_REQUIRES(context, upper_ < static_cast<T>(1),
-                errors::InvalidArgument("Need upper < 1, got ",
-                                        upper_));
+                errors::InvalidArgument("Need upper < 1, got ", upper_));
     OP_REQUIRES(
         context, lower_ <= upper_,
         errors::InvalidArgument("lower must be less than or equal to upper."));
@@ -101,9 +99,9 @@ class RreluOp : public OpKernel {
     OP_REQUIRES_OK(context, context->allocate_output(1, input_tensor.shape(),
                                                      &alpha_tensor));
     // functor::Rrelu<Device, T> functor;
-    functor::Rrelu<Device, T>()(context->eigen_device<Device>(), input_tensor.flat<T>(), lower_,
-                                upper_, training_, output_tensor->flat<T>(),
-                                alpha_tensor->flat<T>());
+    functor::Rrelu<Device, T>()(
+        context->eigen_device<Device>(), input_tensor.flat<T>(), lower_, upper_,
+        training_, output_tensor->flat<T>(), alpha_tensor->flat<T>());
   }

private:
@@ -122,12 +120,10 @@ class RreluGradOp : public OpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("training", &training_));
     lower_ = static_cast<T>(lower);
     OP_REQUIRES(context, lower_ >= static_cast<T>(0),
-                errors::InvalidArgument("Need lower >= 0, got ",
-                                        lower_));
+                errors::InvalidArgument("Need lower >= 0, got ", lower_));
     upper_ = static_cast<T>(upper);
     OP_REQUIRES(context, upper_ < static_cast<T>(1),
-                errors::InvalidArgument("Need upper < 1, got ",
-                                        upper_));
+                errors::InvalidArgument("Need upper < 1, got ", upper_));
     OP_REQUIRES(
         context, lower_ <= upper_,
         errors::InvalidArgument("lower must be less than or equal to upper."));
@@ -140,9 +136,10 @@ class RreluGradOp : public OpKernel {
     OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                      &output_tensor));
     // functor::RreluGrad<Device, T> functor;
-    functor::RreluGrad<Device, T>()(context->eigen_device<Device>(), gradients.flat<T>(),
-                                    input_tensor.flat<T>(), alpha_tensor.flat<T>(), lower_, upper_,
-                                    training_, output_tensor->flat<T>());
+    functor::RreluGrad<Device, T>()(context->eigen_device<Device>(),
+                                    gradients.flat<T>(), input_tensor.flat<T>(),
+                                    alpha_tensor.flat<T>(), lower_, upper_,
+                                    training_, output_tensor->flat<T>());
   }

private:
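One detail worth calling out in the first rrelu_op.h hunk: Eigen's random() draws uniformly from [-1, 1] for floating-point types, so the training branch's expression lower + (r + 1) * (upper - lower) / 2 lands exactly in [lower, upper]. A quick NumPy check of that mapping (a sketch, assuming the Eigen range):

    import numpy as np

    lower, upper = 0.125, 1.0 / 3.0
    r = np.random.uniform(-1.0, 1.0, size=100000)  # stand-in for Eigen's random()
    alpha = lower + (r + 1.0) * (upper - lower) / 2.0
    assert lower <= alpha.min() and alpha.max() <= upper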
