From c8231fb17bf2d5bfc71929bb9b180fa825d25207 Mon Sep 17 00:00:00 2001 From: Jhuo IH <41447049+autoih@users.noreply.github.com> Date: Wed, 29 Jan 2020 23:12:52 -0800 Subject: [PATCH] typing activation functions (#975) * typing activation functions * update typing info * sanity correction --- tensorflow_addons/activations/hardshrink.py | 7 ++++++- tensorflow_addons/activations/lisht.py | 4 +++- tensorflow_addons/activations/mish.py | 4 +++- tensorflow_addons/activations/rrelu.py | 6 +++++- tensorflow_addons/activations/softshrink.py | 7 ++++++- tensorflow_addons/activations/sparsemax.py | 4 +++- tensorflow_addons/activations/tanhshrink.py | 4 +++- tools/ci_build/verify/check_typing_info.py | 8 -------- 8 files changed, 29 insertions(+), 15 deletions(-) diff --git a/tensorflow_addons/activations/hardshrink.py b/tensorflow_addons/activations/hardshrink.py index 7f718e815b..5d66a2ef46 100644 --- a/tensorflow_addons/activations/hardshrink.py +++ b/tensorflow_addons/activations/hardshrink.py @@ -14,13 +14,18 @@ # ============================================================================== import tensorflow as tf +from tensorflow_addons.utils.types import Number + +from tensorflow_addons.utils import types from tensorflow_addons.utils.resource_loader import LazySO _activation_so = LazySO("custom_ops/activations/_activation_ops.so") @tf.keras.utils.register_keras_serializable(package="Addons") -def hardshrink(x, lower=-0.5, upper=0.5): +def hardshrink( + x: types.TensorLike, lower: Number = -0.5, upper: Number = 0.5 +) -> tf.Tensor: """Hard shrink function. 
Computes hard shrink function: diff --git a/tensorflow_addons/activations/lisht.py b/tensorflow_addons/activations/lisht.py index 1e2a046f72..e8f79958ec 100644 --- a/tensorflow_addons/activations/lisht.py +++ b/tensorflow_addons/activations/lisht.py @@ -14,13 +14,15 @@ # ============================================================================== import tensorflow as tf + +from tensorflow_addons.utils import types from tensorflow_addons.utils.resource_loader import LazySO _activation_so = LazySO("custom_ops/activations/_activation_ops.so") @tf.keras.utils.register_keras_serializable(package="Addons") -def lisht(x): +def lisht(x: types.TensorLike) -> tf.Tensor: """LiSHT: Non-Parameteric Linearly Scaled Hyperbolic Tangent Activation Function. Computes linearly scaled hyperbolic tangent (LiSHT): `x * tanh(x)` diff --git a/tensorflow_addons/activations/mish.py b/tensorflow_addons/activations/mish.py index 714868bd72..b862b122cb 100644 --- a/tensorflow_addons/activations/mish.py +++ b/tensorflow_addons/activations/mish.py @@ -14,13 +14,15 @@ # ============================================================================== import tensorflow as tf + +from tensorflow_addons.utils import types from tensorflow_addons.utils.resource_loader import LazySO _activation_so = LazySO("custom_ops/activations/_activation_ops.so") @tf.keras.utils.register_keras_serializable(package="Addons") -def mish(x): +def mish(x: types.TensorLike) -> tf.Tensor: """Mish: A Self Regularized Non-Monotonic Neural Activation Function. 
Computes mish activation: x * tanh(softplus(x)) diff --git a/tensorflow_addons/activations/rrelu.py b/tensorflow_addons/activations/rrelu.py index 3d99736dca..ead5ff4c0f 100644 --- a/tensorflow_addons/activations/rrelu.py +++ b/tensorflow_addons/activations/rrelu.py @@ -14,10 +14,14 @@ # ============================================================================== import tensorflow as tf +from tensorflow_addons.utils.types import Number +from tensorflow_addons.utils import types +from typing import Optional @tf.keras.utils.register_keras_serializable(package='Addons') -def rrelu(x, lower=0.125, upper=0.3333333333333333, training=None, seed=None): +def rrelu(x: types.TensorLike, lower: Number = 0.125, upper: Number = 0.3333333333333333, + training: Optional[bool] = None, seed: Optional[int] = None) -> tf.Tensor: """rrelu function. Computes rrelu function:
Computes soft shrink function: diff --git a/tensorflow_addons/activations/sparsemax.py b/tensorflow_addons/activations/sparsemax.py index 150ceba6c6..1dec26d58c 100644 --- a/tensorflow_addons/activations/sparsemax.py +++ b/tensorflow_addons/activations/sparsemax.py @@ -15,10 +15,12 @@ import tensorflow as tf +from tensorflow_addons.utils import types + @tf.keras.utils.register_keras_serializable(package="Addons") @tf.function -def sparsemax(logits, axis=-1): +def sparsemax(logits: types.TensorLike, axis: int = -1) -> tf.Tensor: """Sparsemax activation function [1]. For each batch `i` and class `j` we have diff --git a/tensorflow_addons/activations/tanhshrink.py b/tensorflow_addons/activations/tanhshrink.py index e97c6ee613..be41d0e9ee 100644 --- a/tensorflow_addons/activations/tanhshrink.py +++ b/tensorflow_addons/activations/tanhshrink.py @@ -14,13 +14,15 @@ # ============================================================================== import tensorflow as tf + +from tensorflow_addons.utils import types from tensorflow_addons.utils.resource_loader import LazySO _activation_so = LazySO("custom_ops/activations/_activation_ops.so") @tf.keras.utils.register_keras_serializable(package="Addons") -def tanhshrink(x): +def tanhshrink(x: types.TensorLike) -> tf.Tensor: """Applies the element-wise function: x - tanh(x) Args: diff --git a/tools/ci_build/verify/check_typing_info.py b/tools/ci_build/verify/check_typing_info.py index 6510599d2b..6601a97d26 100644 --- a/tools/ci_build/verify/check_typing_info.py +++ b/tools/ci_build/verify/check_typing_info.py @@ -25,14 +25,6 @@ # TODO: add types and remove all elements from # the exception list. 
EXCEPTION_LIST = [ - tensorflow_addons.activations.hardshrink, - tensorflow_addons.activations.gelu, - tensorflow_addons.activations.lisht, - tensorflow_addons.activations.mish, - tensorflow_addons.activations.tanhshrink, - tensorflow_addons.activations.sparsemax, - tensorflow_addons.activations.softshrink, - tensorflow_addons.activations.rrelu, tensorflow_addons.callbacks.TQDMProgressBar, tensorflow_addons.callbacks.TimeStopping, tensorflow_addons.image.connected_components,