typing activation functions (#975)
* typing activation functions

* update typing info

* sanity correction
autoih authored Jan 30, 2020
1 parent 4da5e97 commit c8231fb
Showing 8 changed files with 29 additions and 15 deletions.
7 changes: 6 additions & 1 deletion tensorflow_addons/activations/hardshrink.py
@@ -14,13 +14,18 @@
 # ==============================================================================
 
 import tensorflow as tf
+from tensorflow_addons.utils.types import Number
 
+from tensorflow_addons.utils import types
 from tensorflow_addons.utils.resource_loader import LazySO
 
 _activation_so = LazySO("custom_ops/activations/_activation_ops.so")
 
 
 @tf.keras.utils.register_keras_serializable(package="Addons")
-def hardshrink(x, lower=-0.5, upper=0.5):
+def hardshrink(
+    x: types.TensorLike, lower: Number = -0.5, upper: Number = 0.5
+) -> tf.Tensor:
     """Hard shrink function.
     Computes hard shrink function:
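A minimal usage sketch of the newly typed signature (assuming `tensorflow` and `tensorflow-addons` are installed; the input values are illustrative):

import tensorflow as tf
import tensorflow_addons as tfa

# TensorLike also admits numpy arrays and plain Python lists, not just tensors.
x = tf.constant([-1.0, -0.3, 0.0, 0.4, 1.0])
# Elements inside [lower, upper] are zeroed; the rest pass through unchanged.
print(tfa.activations.hardshrink(x, lower=-0.5, upper=0.5))
# -> approximately [-1.0, 0.0, 0.0, 0.0, 1.0]
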
4 changes: 3 additions & 1 deletion tensorflow_addons/activations/lisht.py
@@ -14,13 +14,15 @@
 # ==============================================================================
 
 import tensorflow as tf
 
+from tensorflow_addons.utils import types
 from tensorflow_addons.utils.resource_loader import LazySO
 
 _activation_so = LazySO("custom_ops/activations/_activation_ops.so")
 
 
 @tf.keras.utils.register_keras_serializable(package="Addons")
-def lisht(x):
+def lisht(x: types.TensorLike) -> tf.Tensor:
     """LiSHT: Non-Parametric Linearly Scaled Hyperbolic Tangent Activation Function.
     Computes linearly scaled hyperbolic tangent (LiSHT): `x * tanh(x)`
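A quick numeric sketch checking the typed function against its formula (illustrative values):

import tensorflow as tf
import tensorflow_addons as tfa

x = tf.constant([-2.0, 0.0, 2.0])
out = tfa.activations.lisht(x)
# lisht(x) = x * tanh(x), so both branches should agree (~[1.928, 0.0, 1.928]).
expected = x * tf.math.tanh(x)
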
4 changes: 3 additions & 1 deletion tensorflow_addons/activations/mish.py
@@ -14,13 +14,15 @@
 # ==============================================================================
 
 import tensorflow as tf
 
+from tensorflow_addons.utils import types
 from tensorflow_addons.utils.resource_loader import LazySO
 
 _activation_so = LazySO("custom_ops/activations/_activation_ops.so")
 
 
 @tf.keras.utils.register_keras_serializable(package="Addons")
-def mish(x):
+def mish(x: types.TensorLike) -> tf.Tensor:
     """Mish: A Self Regularized Non-Monotonic Neural Activation Function.
     Computes mish activation: x * tanh(softplus(x))
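Similarly, a small sketch checking mish against its defining formula (illustrative values):

import tensorflow as tf
import tensorflow_addons as tfa

x = tf.constant([-1.0, 0.0, 1.0])
out = tfa.activations.mish(x)
# mish(x) = x * tanh(softplus(x)); mish(0) == 0.0 and mish(1) ~= 0.865.
expected = x * tf.math.tanh(tf.math.softplus(x))
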
6 changes: 5 additions & 1 deletion tensorflow_addons/activations/rrelu.py
@@ -14,10 +14,14 @@
 # ==============================================================================
 
 import tensorflow as tf
+from tensorflow_addons.utils.types import Number
+from tensorflow_addons.utils import types
+from typing import Optional
 
 
 @tf.keras.utils.register_keras_serializable(package='Addons')
-def rrelu(x, lower=0.125, upper=0.3333333333333333, training=None, seed=None):
+def rrelu(x: types.TensorLike, lower: Number = 0.125, upper: Number = 0.3333333333333333,
+          training: Optional[bool] = None, seed: Optional[int] = None) -> tf.Tensor:
     """rrelu function.
     Computes rrelu function:
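A usage sketch showing why `training` is a boolean flag and `seed` an integer (assumptions as above; values illustrative):

import tensorflow as tf
import tensorflow_addons as tfa

x = tf.constant([-2.0, -1.0, 0.0, 1.0])
# In training mode the negative-side slope is drawn uniformly from
# [lower, upper]; passing an integer seed makes that draw reproducible.
train_out = tfa.activations.rrelu(x, training=True, seed=42)
# In inference mode the deterministic mean slope (lower + upper) / 2 is used.
eval_out = tfa.activations.rrelu(x, training=False)
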
7 changes: 6 additions & 1 deletion tensorflow_addons/activations/softshrink.py
@@ -14,13 +14,18 @@
 # ==============================================================================
 
 import tensorflow as tf
+from tensorflow_addons.utils.types import Number
 
+from tensorflow_addons.utils import types
 from tensorflow_addons.utils.resource_loader import LazySO
 
 _activation_so = LazySO("custom_ops/activations/_activation_ops.so")
 
 
 @tf.keras.utils.register_keras_serializable(package="Addons")
-def softshrink(x, lower=-0.5, upper=0.5):
+def softshrink(
+    x: types.TensorLike, lower: Number = -0.5, upper: Number = 0.5
+) -> tf.Tensor:
     """Soft shrink function.
     Computes soft shrink function:
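A minimal usage sketch of the typed softshrink (illustrative values):

import tensorflow as tf
import tensorflow_addons as tfa

x = tf.constant([-1.0, -0.2, 0.0, 0.2, 1.0])
# Values inside [lower, upper] map to zero; values outside are shrunk
# toward zero by the nearest bound, e.g. f(1.0) = 1.0 - 0.5 = 0.5.
print(tfa.activations.softshrink(x))  # -> approximately [-0.5, 0.0, 0.0, 0.0, 0.5]
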
4 changes: 3 additions & 1 deletion tensorflow_addons/activations/sparsemax.py
@@ -15,10 +15,12 @@
 
 import tensorflow as tf
 
+from tensorflow_addons.utils import types
+
 
 @tf.keras.utils.register_keras_serializable(package="Addons")
 @tf.function
-def sparsemax(logits, axis=-1):
+def sparsemax(logits: types.TensorLike, axis: int = -1) -> tf.Tensor:
     """Sparsemax activation function [1].
     For each batch `i` and class `j` we have
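A short sketch of the typed sparsemax on a toy batch (illustrative logits):

import tensorflow as tf
import tensorflow_addons as tfa

logits = tf.constant([[1.0, 2.0, 5.0], [2.0, 2.1, 1.9]])
# Each row sums to 1 as with softmax, but sparsemax can assign exact zeros:
# here the first row collapses entirely onto its highest-scoring class.
probs = tfa.activations.sparsemax(logits, axis=-1)
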
4 changes: 3 additions & 1 deletion tensorflow_addons/activations/tanhshrink.py
@@ -14,13 +14,15 @@
 # ==============================================================================
 
 import tensorflow as tf
 
+from tensorflow_addons.utils import types
 from tensorflow_addons.utils.resource_loader import LazySO
 
 _activation_so = LazySO("custom_ops/activations/_activation_ops.so")
 
 
 @tf.keras.utils.register_keras_serializable(package="Addons")
-def tanhshrink(x):
+def tanhshrink(x: types.TensorLike) -> tf.Tensor:
     """Applies the element-wise function: x - tanh(x)
     Args:
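And a matching sketch for tanhshrink (illustrative values):

import tensorflow as tf
import tensorflow_addons as tfa

x = tf.constant([-2.0, 0.0, 2.0])
out = tfa.activations.tanhshrink(x)
# tanhshrink(x) = x - tanh(x), roughly [-1.036, 0.0, 1.036] here.
expected = x - tf.math.tanh(x)
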
8 changes: 0 additions & 8 deletions tools/ci_build/verify/check_typing_info.py
@@ -25,14 +25,6 @@
 # TODO: add types and remove all elements from
 # the exception list.
 EXCEPTION_LIST = [
-    tensorflow_addons.activations.hardshrink,
-    tensorflow_addons.activations.gelu,
-    tensorflow_addons.activations.lisht,
-    tensorflow_addons.activations.mish,
-    tensorflow_addons.activations.tanhshrink,
-    tensorflow_addons.activations.sparsemax,
-    tensorflow_addons.activations.softshrink,
-    tensorflow_addons.activations.rrelu,
     tensorflow_addons.callbacks.TQDMProgressBar,
     tensorflow_addons.callbacks.TimeStopping,
     tensorflow_addons.image.connected_components,
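The verify script appears to consume EXCEPTION_LIST while checking that exported symbols are fully annotated. A simplified, hypothetical sketch of such a completeness check (not the actual script; `has_complete_type_hints` is an illustrative helper):

import inspect

import tensorflow_addons as tfa

def has_complete_type_hints(func) -> bool:
    # Passes only when every parameter and the return value are annotated;
    # anything that fails would have to stay on the exception list.
    sig = inspect.signature(func)
    params_ok = all(
        p.annotation is not inspect.Parameter.empty
        for p in sig.parameters.values()
    )
    return params_ok and sig.return_annotation is not inspect.Signature.empty

# After this commit, the newly typed activations should pass:
assert has_complete_type_hints(tfa.activations.hardshrink)
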
