diff --git a/tensorflow_addons/activations/rrelu.py b/tensorflow_addons/activations/rrelu.py
index 3309d3cf9e..07a0bbfa74 100644
--- a/tensorflow_addons/activations/rrelu.py
+++ b/tensorflow_addons/activations/rrelu.py
@@ -31,7 +31,7 @@ def rrelu(x, lower=0.125, upper=0.3333333333333333, training=None, seed=None):
     """rrelu function.
 
     Computes rrelu function:
-    `x if x > 0 else random(lower,upper) * x` or
+    `x if x > 0 else random(lower, upper) * x` or
     `x if x > 0 else x * (lower + upper) / 2`
     depending on whether training is enabled.
 
@@ -44,6 +44,7 @@ def rrelu(x, lower=0.125, upper=0.3333333333333333, training=None, seed=None):
         upper: `float`, upper bound for random alpha.
         training: `bool`, indicating whether the `call`
         is meant for training or inference.
+        seed: `int`, this sets the operation-level seed.
     Returns:
         result: A `Tensor`. Has the same type as `x`.
     """
@@ -51,6 +52,7 @@ def rrelu(x, lower=0.125, upper=0.3333333333333333, training=None, seed=None):
     if training is None:
         training = tf.keras.backend.learning_phase()
         training = bool(tf.keras.backend.get_value(training))
+    # TODO: get rid of v1 API
     seed1, seed2 = tf.compat.v1.random.get_seed(seed)
     result, _ = _activation_ops_so.addons_rrelu(x, lower, upper, training,
                                                 seed1, seed2)
diff --git a/tensorflow_addons/activations/rrelu_test.py b/tensorflow_addons/activations/rrelu_test.py
old mode 100755
new mode 100644
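
For reviewers, the semantics documented in the docstring above map onto plain TensorFlow ops roughly as follows. This is a minimal sketch of the documented behavior only, assuming float inputs; it is not the custom addons_rrelu C++ kernel the diff calls into, and the helper name rrelu_reference is made up for illustration.

    import tensorflow as tf

    def rrelu_reference(x, lower=0.125, upper=1 / 3, training=False, seed=None):
        """Sketch of rrelu semantics (hypothetical helper, not the real kernel)."""
        x = tf.convert_to_tensor(x)
        if training:
            # Training: negative inputs are scaled by an alpha drawn uniformly
            # from [lower, upper), matching `random(lower, upper) * x`.
            alpha = tf.random.uniform(
                tf.shape(x), minval=lower, maxval=upper, dtype=x.dtype, seed=seed
            )
        else:
            # Inference: negative inputs use the fixed mean slope (lower + upper) / 2.
            alpha = tf.cast((lower + upper) / 2, x.dtype)
        return tf.where(x > 0, x, alpha * x)

For example, rrelu_reference(tf.constant([-1.0, 2.0]), training=False) yields approximately [-0.2292, 2.0], since (0.125 + 1/3) / 2 ≈ 0.2292 and positive inputs pass through unchanged.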