diff --git a/tensorflow_addons/activations/gelu.py b/tensorflow_addons/activations/gelu.py
index 4525516021..e82c58a03d 100644
--- a/tensorflow_addons/activations/gelu.py
+++ b/tensorflow_addons/activations/gelu.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 
 import tensorflow as tf
+import math
 
 from tensorflow_addons.utils import types
 from tensorflow_addons.utils.resource_loader import LazySO
@@ -49,3 +50,26 @@ def _gelu_grad(op, grad):
     return _activation_so.ops.addons_gelu_grad(
         grad, op.inputs[0], op.get_attr("approximate")
     )
+
+
+def _gelu_py(x: types.TensorLike, approximate: bool = True) -> tf.Tensor:
+    """Compute GELU using only standard TensorFlow ops.
+
+    Args:
+        x: Input value; converted with `tf.convert_to_tensor`.
+        approximate: If True, use the tanh-based approximation; otherwise
+            use the erf-based formulation.
+
+    Returns:
+        A tensor holding the activation of `x`.
+    """
+    x = tf.convert_to_tensor(x)
+    if approximate:
+        pi = tf.cast(math.pi, x.dtype)
+        coeff = tf.cast(0.044715, x.dtype)
+        return 0.5 * x * (1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))))
+    else:
+        # Use math.sqrt so the constant is computed in double precision before
+        # the cast; tf.sqrt(2.0) would produce a float32 value first.
+        sqrt_two = tf.cast(math.sqrt(2.0), x.dtype)
+        return 0.5 * x * (1.0 + tf.math.erf(x / sqrt_two))
diff --git a/tensorflow_addons/activations/gelu_test.py b/tensorflow_addons/activations/gelu_test.py
index 53c8499eda..da44fb4315 100644
--- a/tensorflow_addons/activations/gelu_test.py
+++ b/tensorflow_addons/activations/gelu_test.py
@@ -18,6 +18,7 @@
 import numpy as np
 import tensorflow as tf
 from tensorflow_addons.activations import gelu
+from tensorflow_addons.activations.gelu import _gelu_py
 from tensorflow_addons.utils import test_utils
 
 
@@ -51,6 +52,28 @@ def test_theoretical_gradients(self, dtype):
                 )
                 self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
 
+    @parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
+    def test_same_as_py_func(self, dtype):
+        """Compare `gelu` with `_gelu_py` on 20 seeded random inputs."""
+        np.random.seed(100)
+        for _ in range(20):
+            self.verify_funcs_are_equivalent(dtype)
+
+    def verify_funcs_are_equivalent(self, dtype):
+        """Assert `gelu` and `_gelu_py` agree on outputs and gradients."""
+        x_np = np.random.uniform(-10, 10, size=(4, 4)).astype(dtype)
+        x = tf.convert_to_tensor(x_np)
+        for approximate in [True, False]:
+            # Persistent so that two gradients can be taken from one tape.
+            with tf.GradientTape(persistent=True) as t:
+                t.watch(x)
+                y_native = gelu(x, approximate=approximate)
+                y_py = _gelu_py(x, approximate=approximate)
+            self.assertAllCloseAccordingToType(y_native, y_py, atol=1e-4)
+            grad_native = t.gradient(y_native, x)
+            grad_py = t.gradient(y_py, x)
+            self.assertAllCloseAccordingToType(grad_native, grad_py, atol=1e-4)
+
 
 if __name__ == "__main__":
     tf.test.main()