Commit
change test case
fsx950223 committed Oct 8, 2019
1 parent dfbe2c5 commit c673733
Showing 9 changed files with 43 additions and 41 deletions.
2 changes: 1 addition & 1 deletion tensorflow_addons/activations/BUILD
@@ -9,8 +9,8 @@ py_library(
"gelu.py",
"hardshrink.py",
"lisht.py",
"softshrink.py",
"rrelu.py",
"softshrink.py",
"sparsemax.py",
"tanhshrink.py",
],
3 changes: 2 additions & 1 deletion tensorflow_addons/activations/activations_test.py
@@ -26,7 +26,8 @@
class ActivationsTest(tf.test.TestCase):

ALL_ACTIVATIONS = [
"gelu", "hardshrink", "lisht", "softshrink", "sparsemax","rrelu", "tanhshrink"
"gelu", "hardshrink", "lisht", "softshrink", "sparsemax", "rrelu",
"tanhshrink"
]

def test_serialization(self):
15 changes: 9 additions & 6 deletions tensorflow_addons/activations/rrelu.py
100644 → 100755
@@ -27,11 +27,14 @@

@keras_utils.register_keras_custom_object
@tf.function
def rrelu(x, lower=0.125, upper=0.3333333333333333, training=None):
def rrelu(x, lower=0.125, upper=0.3333333333333333, training=None, seed=None):
"""rrelu function.
Computes rrelu function:
`x if x > 0 else random(lower,upper)*x`.
`x if x > 0 else random(lower,upper) * x` or `x if x > 0 else x * (lower + upper) / 2`
depending on whether training is enabled.
See [Empirical Evaluation of Rectified Activations in Convolutional Network](https://arxiv.org/abs/1505.00853).
Args:
x: A `Tensor`. Must be one of the following types:
@@ -43,11 +46,10 @@ def rrelu(x, lower=0.125, upper=0.3333333333333333, training=None):
A `Tensor`. Has the same type as `x`.
"""
x = tf.convert_to_tensor(x)
lower = tf.convert_to_tensor(x)
upper = tf.convert_to_tensor(x, dtype=lower.dtype)
if training is None:
training = tf.keras.backend.learning_phase()
return _activation_ops_so.rrelu(x, lower, upper, training)
training = bool(tf.keras.backend.get_value(training))
return _activation_ops_so.addons_rrelu(x, lower, upper, training, seed)


@tf.RegisterGradient("Addons>Rrelu")
@@ -56,4 +58,5 @@ def _rrelu_grad(op, grad):
op.outputs[1],
op.get_attr("lower"),
op.get_attr("upper"),
op.get_attr("training"))
op.get_attr("training"),
op.get_attr("seed"))
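For context, here is a minimal NumPy sketch of the behaviour the updated docstring describes — illustrative only, not the custom op itself; the fixed seed and sample values below are assumptions:

import numpy as np

def rrelu_reference(x, lower=0.125, upper=1 / 3, training=False, rng=None):
    # Training: scale negative inputs by a random alpha drawn from [lower, upper].
    # Inference: scale them by the fixed mean (lower + upper) / 2.
    x = np.asarray(x, dtype=np.float64)
    if training:
        rng = rng if rng is not None else np.random.default_rng(111111)
        alpha = rng.uniform(lower, upper, size=x.shape)
    else:
        alpha = np.full_like(x, (lower + upper) / 2)
    return np.where(x >= 0, x, alpha * x), alpha

# Inference mode scales negatives by (0.125 + 1/3) / 2 ≈ 0.229:
y, _ = rrelu_reference([-2.0, -1.0, 0.0, 1.0, 2.0])  # ≈ [-0.458, -0.229, 0., 1., 2.]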
37 changes: 15 additions & 22 deletions tensorflow_addons/activations/rrelu_test.py
100644 → 100755
@@ -27,37 +27,30 @@
from tensorflow_addons.utils import test_utils
import random

SEED = 111111
tf.random.set_seed(SEED)
random.seed(SEED)

def _ref_rrelu(x, lower, upper, alpha, training=None):
if training:
return tf.where(x >= 0, x, alpha * x)
else:
return tf.where(x >= 0, x, x * (lower + upper) / 2)


@test_utils.run_all_in_graph_and_eager_modes
class RreluTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(("float16", np.float16),
("float32", np.float32),
("float64", np.float64))
def test_lisht(self, dtype):
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
expected_result = tf.constant(
[1.9280552, 0.7615942, 0.0, 0.7615942, 1.9280552], dtype=dtype)
self.assertAllCloseAccordingToType(rrelu(x), expected_result)

@parameterized.named_parameters(("float32", np.float32),
("float64", np.float64))
def test_theoretical_gradients(self, dtype):
# Only test theoretical gradients for float32 and float64
# because of the instability of float16 while computing jacobian
def test_rrelu_training(self, dtype):
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
lower = 0.1
upper = 0.2
# result,alpha=rrelu(x,lower,upper, training=True)
# self.assertAllCloseAccordingToType(result, _ref_rrelu(x,lower,upper,alpha,training=True))

theoretical, numerical = tf.test.compute_gradient(rrelu, [x])
result, alpha = rrelu(x, lower, upper, training=False)
self.assertAllCloseAccordingToType(
theoretical, numerical, rtol=5e-4, atol=5e-4)
result, _ref_rrelu(x, lower, upper, alpha, training=False))

def test_unknown_shape(self):
fn = rrelu.get_concrete_function(
tf.TensorSpec(shape=None, dtype=tf.float32))

for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
x = tf.ones(shape=shape, dtype=tf.float32)
self.assertAllClose(fn(x), rrelu(x))
if __name__ == "__main__":
tf.test.main()
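The commented-out lines above hint at a training-mode assertion that is not yet enabled. A hedged sketch of what such a test might look like, reusing SEED and _ref_rrelu from this file and assuming the seed argument (still commented out in the op registration in this commit) makes the sampled alpha reproducible; the method name is hypothetical:

    @parameterized.named_parameters(("float32", np.float32),
                                    ("float64", np.float64))
    def test_rrelu_training_with_seed(self, dtype):
        # Sketch only: relies on rrelu returning the sampled alpha alongside
        # the result, as the inference test above already does.
        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
        lower, upper = 0.1, 0.2
        result, alpha = rrelu(x, lower, upper, training=True, seed=SEED)
        self.assertAllCloseAccordingToType(
            result, _ref_rrelu(x, lower, upper, alpha, training=True))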
6 changes: 3 additions & 3 deletions tensorflow_addons/custom_ops/activations/BUILD
@@ -146,17 +146,17 @@ cc_binary(
"cc/kernels/hardshrink_op.h",
"cc/kernels/lisht_op.cc",
"cc/kernels/lisht_op.h",
"cc/kernels/softshrink_op.cc",
"cc/kernels/softshrink_op.h",
"cc/kernels/rrelu_op.cc",
"cc/kernels/rrelu_op.h",
"cc/kernels/softshrink_op.cc",
"cc/kernels/softshrink_op.h",
"cc/kernels/tanhshrink_op.cc",
"cc/kernels/tanhshrink_op.h",
"cc/ops/gelu_op.cc",
"cc/ops/hardshrink_op.cc",
"cc/ops/lisht_op.cc",
"cc/ops/softshrink_op.cc",
"cc/ops/rrelu_op.cc",
"cc/ops/softshrink_op.cc",
"cc/ops/tanhshrink_op.cc",
],
copts = [
2 changes: 1 addition & 1 deletion tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op.cc
100644 → 100755
@@ -73,5 +73,5 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_RRELU_GPU_KERNELS);
#undef REGISTER_RRELU_GPU_KERNELS

#endif // GOOGLE_CUDA
} // end namespace addons
} // namespace addons
} // namespace tensorflow
15 changes: 9 additions & 6 deletions tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op.h
@@ -18,6 +18,7 @@ limitations under the License.

#define EIGEN_USE_THREADS

#include <cstdlib>
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -43,8 +44,9 @@ struct Rrelu {
} else {
activations.device(d) =
(features >= static_cast<T>(0))
.select(features, features.constant(static_cast<T>(2)) *
features / (lower + upper));
.select(features, features *
features.constant((lower + upper) /
static_cast<T>(2)));
}
}
};
@@ -65,7 +67,7 @@ struct RreluGrad {
gradients *
(features >= static_cast<T>(0))
.select(features.constant(static_cast<T>(1)),
features.constant(static_cast<T>(2) / (lower + upper)));
features.constant((lower + upper) / static_cast<T>(2)));
}
}
};
@@ -80,6 +82,7 @@ class RreluOp : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("lower", &lower));
OP_REQUIRES_OK(context, context->GetAttr("upper", &upper));
OP_REQUIRES_OK(context, context->GetAttr("training", &training_));
// OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
lower_ = static_cast<T>(lower);
OP_REQUIRES(context, lower_ >= static_cast<T>(0),
errors::InvalidArgument("Need lower >= 0, got ", lower_));
@@ -98,7 +101,7 @@ class RreluOp : public OpKernel {
&output_tensor));
OP_REQUIRES_OK(context, context->allocate_output(1, input_tensor.shape(),
&alpha_tensor));
// functor::Rrelu<Device, T> functor;
// std::srand(seed_);
functor::Rrelu<Device, T>()(
context->eigen_device<Device>(), input_tensor.flat<T>(), lower_, upper_,
training_, output_tensor->flat<T>(), alpha_tensor->flat<T>());
@@ -108,6 +111,7 @@
T lower_;
T upper_;
bool training_;
int seed_;
};

template <typename Device, typename T>
@@ -135,7 +139,6 @@ class RreluGradOp : public OpKernel {
Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
// functor::RreluGrad<Device, T> functor;
functor::RreluGrad<Device, T>()(context->eigen_device<Device>(),
gradients.flat<T>(), input_tensor.flat<T>(),
alpha_tensor.flat<T>(), lower_, upper_,
@@ -148,7 +151,7 @@ class RreluGradOp : public OpKernel {
bool training_;
};

} // end namespace addons
} // namespace addons
} // namespace tensorflow

#undef EIGEN_USE_THREADS
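For reference, a small Python sketch of what the Rrelu and RreluGrad functors above compute, stated as an assumption: the training-mode gradient uses the alpha produced by the forward pass, which is what the Python gradient registration passes in via op.outputs[1]. Inputs are plain NumPy arrays; these are not the kernels themselves.

import numpy as np

def rrelu_forward_reference(x, alpha, lower, upper, training):
    # Mirrors functor::Rrelu: identity for x >= 0, otherwise alpha * x in
    # training or x * (lower + upper) / 2 at inference.
    scale = alpha if training else np.full_like(x, (lower + upper) / 2)
    return np.where(x >= 0, x, scale * x)

def rrelu_grad_reference(grad, x, alpha, lower, upper, training):
    # Mirrors functor::RreluGrad: pass gradients through for x >= 0, otherwise
    # scale them by alpha (training) or by (lower + upper) / 2 (inference).
    scale = alpha if training else np.full_like(x, (lower + upper) / 2)
    return grad * np.where(x >= 0, np.ones_like(x), scale)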
2 changes: 1 addition & 1 deletion tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op_gpu.cu.cc
100644 → 100755
@@ -32,7 +32,7 @@ using GPUDevice = Eigen::GpuDevice;

TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);

} // end namespace addons
} // namespace addons
} // namespace tensorflow

#endif // GOOGLE_CUDA
2 changes: 2 additions & 0 deletions tensorflow_addons/custom_ops/activations/cc/ops/rrelu_op.cc
100644 → 100755
@@ -28,6 +28,7 @@ REGISTER_OP("Addons>Rrelu")
.Attr("lower: float")
.Attr("upper: float")
.Attr("training: bool")
// .Attr("seed: int")
.SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Addons>RreluGrad")
@@ -39,6 +40,7 @@ REGISTER_OP("Addons>RreluGrad")
.Attr("lower: float")
.Attr("upper: float")
.Attr("training: bool")
// .Attr("seed: int")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);

} // end namespace addons
