diff --git a/tensorflow_addons/image/tests/distort_image_ops_test.py b/tensorflow_addons/image/tests/distort_image_ops_test.py
index 661c1454ab..ee7f801174 100644
--- a/tensorflow_addons/image/tests/distort_image_ops_test.py
+++ b/tensorflow_addons/image/tests/distort_image_ops_test.py
@@ -94,7 +94,7 @@ def test_adjust_random_hue_in_yiq(shape, style, dtype):
     y_np = _adjust_hue_in_yiq_np(x_np, delta_h)
     y_tf = _adjust_hue_in_yiq_tf(x_np, delta_h)
     test_utils.assert_allclose_according_to_type(
-        y_tf, y_np, atol=1e-4, rtol=2e-4, half_rtol=0.8
+        y_tf, y_np, atol=1e-4, rtol=2e-4, half_rtol=1.1
     )
 
 
@@ -121,11 +121,11 @@ def test_invalid_channels_hsv():
 
 def test_adjust_hsv_in_yiq_unknown_shape():
     fn = tf.function(distort_image_ops.adjust_hsv_in_yiq).get_concrete_function(
-        tf.TensorSpec(shape=None, dtype=tf.float64)
+        tf.TensorSpec(shape=None, dtype=tf.float32)
     )
     for shape in (2, 3, 3), (4, 2, 3, 3):
         image_np = np.random.rand(*shape) * 255.0
-        image_tf = tf.constant(image_np)
+        image_tf = tf.constant(image_np, dtype=tf.float32)
         np.testing.assert_allclose(
             _adjust_hue_in_yiq_np(image_np, 0), fn(image_tf), rtol=2e-4, atol=1e-4
         )
diff --git a/tensorflow_addons/optimizers/discriminative_layer_training.py b/tensorflow_addons/optimizers/discriminative_layer_training.py
index a82f1b2d3e..d41c5ee997 100644
--- a/tensorflow_addons/optimizers/discriminative_layer_training.py
+++ b/tensorflow_addons/optimizers/discriminative_layer_training.py
@@ -22,9 +22,20 @@
 from tensorflow_addons.optimizers import KerasLegacyOptimizer
 from typeguard import typechecked
 
-if Version(tf.__version__).release >= Version("2.13").release:
-    # New versions of Keras require importing from `keras.src` when
-    # importing internal symbols.
+if Version(tf.__version__).release >= Version("2.16").release:
+    # Determine if loading keras 2 or 3.
+    if (
+        hasattr(tf.keras, "version")
+        and Version(tf.keras.version()).release >= Version("3.0").release
+    ):
+        # New versions of Keras require importing from `keras.src` when
+        # importing internal symbols.
+        from keras.src import backend
+        from keras.src.utils import tf_utils
+    else:
+        from tf_keras.src import backend
+        from tf_keras.src.utils import tf_utils
+elif Version(tf.__version__).release >= Version("2.13").release:
     from keras.src import backend
     from keras.src.utils import tf_utils
 else:
diff --git a/tensorflow_addons/optimizers/lazy_adam.py b/tensorflow_addons/optimizers/lazy_adam.py
index ad8570bc3c..b09e4e96ad 100644
--- a/tensorflow_addons/optimizers/lazy_adam.py
+++ b/tensorflow_addons/optimizers/lazy_adam.py
@@ -149,3 +149,6 @@ def _resource_scatter_operate(self, resource, indices, update, resource_scatter_
         }
 
         return resource_scatter_op(**resource_update_kwargs)
+
+    def get_config(self):
+        return super().get_config()
diff --git a/tensorflow_addons/rnn/abstract_rnn_cell.py b/tensorflow_addons/rnn/abstract_rnn_cell.py
new file mode 100644
index 0000000000..de5225bf76
--- /dev/null
+++ b/tensorflow_addons/rnn/abstract_rnn_cell.py
@@ -0,0 +1,133 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class for RNN cells.
+
+Adapted from legacy github.com/keras-team/tf-keras.
+"""
+
+import tensorflow as tf
+
+
+def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype):
+    if inputs is not None:
+        batch_size = tf.shape(inputs)[0]
+        dtype = inputs.dtype
+    return _generate_zero_filled_state(batch_size, cell.state_size, dtype)
+
+
+def _generate_zero_filled_state(batch_size_tensor, state_size, dtype):
+    """Generate a zero filled tensor with shape [batch_size, state_size]."""
+    if batch_size_tensor is None or dtype is None:
+        raise ValueError(
+            "batch_size and dtype cannot be None while constructing initial state: "
+            "batch_size={}, dtype={}".format(batch_size_tensor, dtype)
+        )
+
+    def create_zeros(unnested_state_size):
+        flat_dims = tf.TensorShape(unnested_state_size).as_list()
+        init_state_size = [batch_size_tensor] + flat_dims
+        return tf.zeros(init_state_size, dtype=dtype)
+
+    if tf.nest.is_nested(state_size):
+        return tf.nest.map_structure(create_zeros, state_size)
+    else:
+        return create_zeros(state_size)
+
+
+class AbstractRNNCell(tf.keras.layers.Layer):
+    """Abstract object representing an RNN cell.
+
+    This is a base class for implementing RNN cells with custom behavior.
+
+    Every `RNNCell` must have the properties below and implement `call` with
+    the signature `(output, next_state) = call(input, state)`.
+
+    Examples:
+
+    ```python
+      class MinimalRNNCell(AbstractRNNCell):
+
+        def __init__(self, units, **kwargs):
+          self.units = units
+          super(MinimalRNNCell, self).__init__(**kwargs)
+
+        @property
+        def state_size(self):
+          return self.units
+
+        def build(self, input_shape):
+          self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
+                                        initializer='uniform',
+                                        name='kernel')
+          self.recurrent_kernel = self.add_weight(
+              shape=(self.units, self.units),
+              initializer='uniform',
+              name='recurrent_kernel')
+          self.built = True
+
+        def call(self, inputs, states):
+          prev_output = states[0]
+          h = backend.dot(inputs, self.kernel)
+          output = h + backend.dot(prev_output, self.recurrent_kernel)
+          return output, output
+    ```
+
+    This definition of cell differs from the definition used in the literature.
+    In the literature, 'cell' refers to an object with a single scalar output.
+    This definition refers to a horizontal array of such units.
+
+    An RNN cell, in the most abstract setting, is anything that has
+    a state and performs some operation that takes a matrix of inputs.
+    This operation results in an output matrix with `self.output_size` columns.
+    If `self.state_size` is an integer, this operation also results in a new
+    state matrix with `self.state_size` columns.  If `self.state_size` is a
+    (possibly nested tuple of) TensorShape object(s), then it should return a
+    matching structure of Tensors having shape `[batch_size].concatenate(s)`
+    for each `s` in `self.batch_size`.
+    """
+
+    def call(self, inputs, states):
+        """The function that contains the logic for one RNN step calculation.
+
+        Args:
+          inputs: the input tensor, which is a slide from the overall RNN input by
+            the time dimension (usually the second dimension).
+          states: the state tensor from previous step, which has the same shape
+            as `(batch, state_size)`. In the case of timestep 0, it will be the
+            initial state user specified, or zero filled tensor otherwise.
+
+        Returns:
+          A tuple of two tensors:
+            1. output tensor for the current timestep, with size `output_size`.
+            2. state tensor for next step, which has the shape of `state_size`.
+        """
+        raise NotImplementedError("Abstract method")
+
+    @property
+    def state_size(self):
+        """size(s) of state(s) used by this cell.
+
+        It can be represented by an Integer, a TensorShape or a tuple of Integers
+        or TensorShapes.
+        """
+        raise NotImplementedError("Abstract method")
+
+    @property
+    def output_size(self):
+        """Integer or TensorShape: size of outputs produced by this cell."""
+        raise NotImplementedError("Abstract method")
+
+    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
+        return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
diff --git a/tensorflow_addons/rnn/esn_cell.py b/tensorflow_addons/rnn/esn_cell.py
index 835da96e98..2147c07de8 100644
--- a/tensorflow_addons/rnn/esn_cell.py
+++ b/tensorflow_addons/rnn/esn_cell.py
@@ -15,9 +15,9 @@
 """Implements ESN Cell."""
 
 import tensorflow as tf
-import tensorflow.keras as keras
 from typeguard import typechecked
 
+from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
 from tensorflow_addons.utils.types import (
     Activation,
     Initializer,
@@ -25,7 +25,7 @@
 
 
 @tf.keras.utils.register_keras_serializable(package="Addons")
-class ESNCell(keras.layers.AbstractRNNCell):
+class ESNCell(AbstractRNNCell):
     """Echo State recurrent Network (ESN) cell.
     This implements the recurrent cell from the paper:
         H. Jaeger
diff --git a/tensorflow_addons/rnn/nas_cell.py b/tensorflow_addons/rnn/nas_cell.py
index ce6ca766ce..f5304d1c12 100644
--- a/tensorflow_addons/rnn/nas_cell.py
+++ b/tensorflow_addons/rnn/nas_cell.py
@@ -15,9 +15,9 @@
 """Implements NAS Cell."""
 
 import tensorflow as tf
-import tensorflow.keras as keras
 from typeguard import typechecked
 
+from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
 from tensorflow_addons.utils.types import (
     FloatTensorLike,
     TensorLike,
@@ -27,7 +27,7 @@
 
 
 @tf.keras.utils.register_keras_serializable(package="Addons")
-class NASCell(keras.layers.AbstractRNNCell):
+class NASCell(AbstractRNNCell):
     """Neural Architecture Search (NAS) recurrent network cell.
 
     This implements the recurrent cell from the paper:
diff --git a/tensorflow_addons/seq2seq/BUILD b/tensorflow_addons/seq2seq/BUILD
index 0674740e58..8f7b8470b3 100644
--- a/tensorflow_addons/seq2seq/BUILD
+++ b/tensorflow_addons/seq2seq/BUILD
@@ -10,6 +10,7 @@ py_library(
         "//tensorflow_addons/custom_ops/seq2seq:_beam_search_ops.so",
     ],
     deps = [
+        "//tensorflow_addons/rnn",
         "//tensorflow_addons/testing",
         "//tensorflow_addons/utils",
     ],
diff --git a/tensorflow_addons/seq2seq/attention_wrapper.py b/tensorflow_addons/seq2seq/attention_wrapper.py
index b1b6f93f2d..d44cbe6a47 100644
--- a/tensorflow_addons/seq2seq/attention_wrapper.py
+++ b/tensorflow_addons/seq2seq/attention_wrapper.py
@@ -23,6 +23,7 @@
 
 import tensorflow as tf
 
+from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
 from tensorflow_addons.utils import keras_utils
 from tensorflow_addons.utils.types import (
     AcceptableDTypes,
@@ -1577,7 +1578,7 @@ def _compute_attention(
     return attention, alignments, next_attention_state
 
 
-class AttentionWrapper(tf.keras.layers.AbstractRNNCell):
+class AttentionWrapper(AbstractRNNCell):
     """Wraps another RNN cell with attention.
 
     Example:
diff --git a/tensorflow_addons/text/BUILD b/tensorflow_addons/text/BUILD
index ae4005d391..79afb5637f 100644
--- a/tensorflow_addons/text/BUILD
+++ b/tensorflow_addons/text/BUILD
@@ -7,17 +7,15 @@ package(default_visibility = ["//visibility:public"])
 py_library(
     name = "text",
     srcs = glob(["*.py"]),
-    data = select({
-        "//tensorflow_addons:windows": [
-            "//tensorflow_addons/custom_ops/text:_skip_gram_ops.so",
-            "//tensorflow_addons/testing",
-            "//tensorflow_addons/utils",
-        ],
+    data = [
+        "//tensorflow_addons/custom_ops/text:_skip_gram_ops.so",
+        "//tensorflow_addons/rnn",
+        "//tensorflow_addons/testing",
+        "//tensorflow_addons/utils",
+    ] + select({
+        "//tensorflow_addons:windows": [],
         "//conditions:default": [
             "//tensorflow_addons/custom_ops/text:_parse_time_op.so",
-            "//tensorflow_addons/custom_ops/text:_skip_gram_ops.so",
-            "//tensorflow_addons/testing",
-            "//tensorflow_addons/utils",
         ],
     }),
 )
diff --git a/tensorflow_addons/text/crf.py b/tensorflow_addons/text/crf.py
index 3820b08a94..287481e546 100644
--- a/tensorflow_addons/text/crf.py
+++ b/tensorflow_addons/text/crf.py
@@ -17,6 +17,7 @@
 import numpy as np
 import tensorflow as tf
 
+from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
 from tensorflow_addons.utils.types import TensorLike
 from typeguard import typechecked
 from typing import Optional, Tuple
@@ -403,7 +404,7 @@ def viterbi_decode(score: TensorLike, transition_params: TensorLike) -> tf.Tenso
     return viterbi, viterbi_score
 
 
-class CrfDecodeForwardRnnCell(tf.keras.layers.AbstractRNNCell):
+class CrfDecodeForwardRnnCell(AbstractRNNCell):
     """Computes the forward decoding in a linear-chain CRF."""
 
     @typechecked
diff --git a/tensorflow_addons/utils/test_utils.py b/tensorflow_addons/utils/test_utils.py
index f998fb4a45..31a43a5536 100644
--- a/tensorflow_addons/utils/test_utils.py
+++ b/tensorflow_addons/utils/test_utils.py
@@ -22,18 +22,10 @@
 import pytest
 import tensorflow as tf
 
-from packaging.version import Version
 from tensorflow_addons import options
 from tensorflow_addons.utils import resource_loader
 
-if Version(tf.__version__).release >= Version("2.13").release:
-    # New versions of Keras require importing from `keras.src` when
-    # importing internal symbols.
-    from keras.src.testing_infra.test_utils import layer_test  # noqa: F401
-elif Version(tf.__version__) >= Version("2.9"):
-    from keras.testing_infra.test_utils import layer_test  # noqa: F401
-else:
-    from keras.testing_utils import layer_test  # noqa: F401
+from tensorflow_addons.utils.tf_test_utils import layer_test  # noqa
 
 NUMBER_OF_WORKERS = int(os.environ.get("PYTEST_XDIST_WORKER_COUNT", "1"))
 WORKER_ID = int(os.environ.get("PYTEST_XDIST_WORKER", "gw0")[2])
diff --git a/tensorflow_addons/utils/tf_inspect.py b/tensorflow_addons/utils/tf_inspect.py
new file mode 100644
index 0000000000..8ca132091a
--- /dev/null
+++ b/tensorflow_addons/utils/tf_inspect.py
@@ -0,0 +1,282 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TFDecorator-aware replacements for the inspect module."""
+import collections
+import functools
+import inspect as _inspect
+
+import tensorflow as tf
+
+if hasattr(_inspect, "ArgSpec"):
+    ArgSpec = _inspect.ArgSpec
+else:
+    ArgSpec = collections.namedtuple(
+        "ArgSpec",
+        [
+            "args",
+            "varargs",
+            "keywords",
+            "defaults",
+        ],
+    )
+
+if hasattr(_inspect, "FullArgSpec"):
+    FullArgSpec = _inspect.FullArgSpec
+else:
+    FullArgSpec = collections.namedtuple(
+        "FullArgSpec",
+        [
+            "args",
+            "varargs",
+            "varkw",
+            "defaults",
+            "kwonlyargs",
+            "kwonlydefaults",
+            "annotations",
+        ],
+    )
+
+
+def _convert_maybe_argspec_to_fullargspec(argspec):
+    if isinstance(argspec, FullArgSpec):
+        return argspec
+    return FullArgSpec(
+        args=argspec.args,
+        varargs=argspec.varargs,
+        varkw=argspec.keywords,
+        defaults=argspec.defaults,
+        kwonlyargs=[],
+        kwonlydefaults=None,
+        annotations={},
+    )
+
+
+if hasattr(_inspect, "getfullargspec"):
+    _getfullargspec = _inspect.getfullargspec
+
+    def _getargspec(target):
+        """A python3 version of getargspec.
+
+        Calls `getfullargspec` and assigns args, varargs,
+        varkw, and defaults to a python 2/3 compatible `ArgSpec`.
+
+        The parameter name 'varkw' is changed to 'keywords' to fit the
+        `ArgSpec` struct.
+
+        Args:
+          target: the target object to inspect.
+
+        Returns:
+          An ArgSpec with args, varargs, keywords, and defaults parameters
+          from FullArgSpec.
+        """
+        fullargspecs = getfullargspec(target)
+        argspecs = ArgSpec(
+            args=fullargspecs.args,
+            varargs=fullargspecs.varargs,
+            keywords=fullargspecs.varkw,
+            defaults=fullargspecs.defaults,
+        )
+        return argspecs
+
+else:
+    _getargspec = _inspect.getargspec
+
+    def _getfullargspec(target):
+        """A python2 version of getfullargspec.
+
+        Args:
+          target: the target object to inspect.
+
+        Returns:
+          A FullArgSpec with empty kwonlyargs, kwonlydefaults and annotations.
+        """
+        return _convert_maybe_argspec_to_fullargspec(getargspec(target))
+
+
+def currentframe():
+    """TFDecorator-aware replacement for inspect.currentframe."""
+    return _inspect.stack()[1][0]
+
+
+def getargspec(obj):
+    """TFDecorator-aware replacement for `inspect.getargspec`.
+
+    Note: `getfullargspec` is recommended as the python 2/3 compatible
+    replacement for this function.
+
+    Args:
+      obj: A function, partial function, or callable object, possibly decorated.
+
+    Returns:
+      The `ArgSpec` that describes the signature of the outermost decorator that
+      changes the callable's signature, or the `ArgSpec` that describes
+      the object if not decorated.
+
+    Raises:
+      ValueError: When callable's signature can not be expressed with
+        ArgSpec.
+      TypeError: For objects of unsupported types.
+    """
+    if isinstance(obj, functools.partial):
+        return _get_argspec_for_partial(obj)
+
+    decorators, target = tf.__internal__.decorator.unwrap(obj)
+
+    spec = next(
+        (d.decorator_argspec for d in decorators if d.decorator_argspec is not None),
+        None,
+    )
+    if spec:
+        return spec
+
+    try:
+        # Python3 will handle most callables here (not partial).
+        return _getargspec(target)
+    except TypeError:
+        pass
+
+    if isinstance(target, type):
+        try:
+            return _getargspec(target.__init__)
+        except TypeError:
+            pass
+
+        try:
+            return _getargspec(target.__new__)
+        except TypeError:
+            pass
+
+    # The `type(target)` ensures that if a class is received we don't return
+    # the signature of its __call__ method.
+    return _getargspec(type(target).__call__)
+
+
+def _get_argspec_for_partial(obj):
+    """Implements `getargspec` for `functools.partial` objects.
+
+    Args:
+      obj: The `functools.partial` object
+    Returns:
+      An `inspect.ArgSpec`
+    Raises:
+      ValueError: When callable's signature can not be expressed with
+        ArgSpec.
+    """
+    # When callable is a functools.partial object, we construct its ArgSpec with
+    # following strategy:
+    # - If callable partial contains default value for positional arguments (ie.
+    # object.args), then final ArgSpec doesn't contain those positional
+    # arguments.
+    # - If callable partial contains default value for keyword arguments (ie.
+    # object.keywords), then we merge them with wrapped target. Default values
+    # from callable partial takes precedence over those from wrapped target.
+    #
+    # However, there is a case where it is impossible to construct a valid
+    # ArgSpec. Python requires arguments that have no default values must be
+    # defined before those with default values. ArgSpec structure is only valid
+    # when this presumption holds true because default values are expressed as a
+    # tuple of values without keywords and they are always assumed to belong to
+    # last K arguments where K is number of default values present.
+    #
+    # Since functools.partial can give default value to any argument, this
+    # presumption may no longer hold in some cases. For example:
+    #
+    # def func(m, n):
+    #   return 2 * m + n
+    # partialed = functools.partial(func, m=1)
+    #
+    # This example will result in m having a default value but n doesn't. This
+    # is usually not allowed in Python and can not be expressed in ArgSpec
+    # correctly.
+    #
+    # Thus, we must detect cases like this by finding first argument with
+    # default value and ensures all following arguments also have default
+    # values. When this is not true, a ValueError is raised.
+
+    n_prune_args = len(obj.args)
+    partial_keywords = obj.keywords or {}
+
+    args, varargs, keywords, defaults = getargspec(obj.func)
+
+    # Pruning first n_prune_args arguments.
+    args = args[n_prune_args:]
+
+    # Partial function may give default value to any argument, therefore length
+    # of default value list must be len(args) to allow each argument to
+    # potentially be given a default value.
+    no_default = object()
+    all_defaults = [no_default] * len(args)
+
+    if defaults:
+        all_defaults[-len(defaults) :] = defaults
+
+    # Fill in default values provided by partial function in all_defaults.
+    for kw, default in partial_keywords.items():
+        if kw in args:
+            idx = args.index(kw)
+            all_defaults[idx] = default
+        elif not keywords:
+            raise ValueError(
+                "Function does not have **kwargs parameter, but "
+                "contains an unknown partial keyword."
+            )
+
+    # Find first argument with default value set.
+    first_default = next(
+        (idx for idx, x in enumerate(all_defaults) if x is not no_default), None
+    )
+
+    # If no default values are found, return ArgSpec with defaults=None.
+    if first_default is None:
+        return ArgSpec(args, varargs, keywords, None)
+
+    # Checks if all arguments have default value set after first one.
+    invalid_default_values = [
+        args[i]
+        for i, j in enumerate(all_defaults)
+        if j is no_default and i > first_default
+    ]
+
+    if invalid_default_values:
+        raise ValueError(
+            f"Some arguments {invalid_default_values} do not have "
+            "default value, but they are positioned after those with "
+            "default values. This can not be expressed with ArgSpec."
+        )
+
+    return ArgSpec(args, varargs, keywords, tuple(all_defaults[first_default:]))
+
+
+def getfullargspec(obj):
+    """TFDecorator-aware replacement for `inspect.getfullargspec`.
+
+    This wrapper emulates `inspect.getfullargspec` in[^)]* Python2.
+
+    Args:
+      obj: A callable, possibly decorated.
+
+    Returns:
+      The `FullArgSpec` that describes the signature of
+      the outermost decorator that changes the callable's signature. If the
+      callable is not decorated, `inspect.getfullargspec()` will be called
+      directly on the callable.
+    """
+    decorators, target = tf.__internal__.decorator.unwrap(obj)
+
+    for d in decorators:
+        if d.decorator_argspec is not None:
+            return _convert_maybe_argspec_to_fullargspec(d.decorator_argspec)
+    return _getfullargspec(target)
diff --git a/tensorflow_addons/utils/tf_test_utils.py b/tensorflow_addons/utils/tf_test_utils.py
new file mode 100644
index 0000000000..bf965d04f2
--- /dev/null
+++ b/tensorflow_addons/utils/tf_test_utils.py
@@ -0,0 +1,290 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for unit-testing TF-Keras."""
+
+
+import threading
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.keras import backend
+from tensorflow.keras import layers
+from tensorflow.keras import models
+from tensorflow_addons.utils import tf_inspect
+
+
+def string_test(actual, expected):
+    np.testing.assert_array_equal(actual, expected)
+
+
+def numeric_test(actual, expected):
+    np.testing.assert_allclose(actual, expected, rtol=1e-3, atol=1e-6)
+
+
+def layer_test(
+    layer_cls,
+    kwargs=None,
+    input_shape=None,
+    input_dtype=None,
+    input_data=None,
+    expected_output=None,
+    expected_output_dtype=None,
+    expected_output_shape=None,
+    validate_training=True,
+    adapt_data=None,
+    custom_objects=None,
+    test_harness=None,
+    supports_masking=None,
+):
+    """Test routine for a layer with a single input and single output.
+
+    Args:
+      layer_cls: Layer class object.
+      kwargs: Optional dictionary of keyword arguments for instantiating the
+        layer.
+      input_shape: Input shape tuple.
+      input_dtype: Data type of the input data.
+      input_data: Numpy array of input data.
+      expected_output: Numpy array of the expected output.
+      expected_output_dtype: Data type expected for the output.
+      expected_output_shape: Shape tuple for the expected shape of the output.
+      validate_training: Whether to attempt to validate training on this layer.
+        This might be set to False for non-differentiable layers that output
+        string or integer values.
+      adapt_data: Optional data for an 'adapt' call. If None, adapt() will not
+        be tested for this layer. This is only relevant for PreprocessingLayers.
+      custom_objects: Optional dictionary mapping name strings to custom objects
+        in the layer class. This is helpful for testing custom layers.
+      test_harness: The Tensorflow test, if any, that this function is being
+        called in.
+      supports_masking: Optional boolean to check the `supports_masking`
+        property of the layer. If None, the check will not be performed.
+
+    Returns:
+      The output data (Numpy array) returned by the layer, for additional
+      checks to be done by the calling code.
+
+    Raises:
+      ValueError: if `input_shape is None`.
+    """
+    if input_data is None:
+        if input_shape is None:
+            raise ValueError("input_shape is None")
+        if not input_dtype:
+            input_dtype = "float32"
+        input_data_shape = list(input_shape)
+        for i, e in enumerate(input_data_shape):
+            if e is None:
+                input_data_shape[i] = np.random.randint(1, 4)
+        input_data = 10 * np.random.random(input_data_shape)
+        if input_dtype[:5] == "float":
+            input_data -= 0.5
+        input_data = input_data.astype(input_dtype)
+    elif input_shape is None:
+        input_shape = input_data.shape
+    if input_dtype is None:
+        input_dtype = input_data.dtype
+    if expected_output_dtype is None:
+        expected_output_dtype = input_dtype
+
+    if tf.as_dtype(expected_output_dtype) == tf.string:
+        if test_harness:
+            assert_equal = test_harness.assertAllEqual
+        else:
+            assert_equal = string_test
+    else:
+        if test_harness:
+            assert_equal = test_harness.assertAllClose
+        else:
+            assert_equal = numeric_test
+
+    # instantiation
+    kwargs = kwargs or {}
+    layer = layer_cls(**kwargs)
+
+    if supports_masking is not None and layer.supports_masking != supports_masking:
+        raise AssertionError(
+            "When testing layer %s, the `supports_masking` property is %r"
+            "but expected to be %r.\nFull kwargs: %s"
+            % (
+                layer_cls.__name__,
+                layer.supports_masking,
+                supports_masking,
+                kwargs,
+            )
+        )
+
+    # Test adapt, if data was passed.
+    if adapt_data is not None:
+        layer.adapt(adapt_data)
+
+    # test get_weights , set_weights at layer level
+    weights = layer.get_weights()
+    layer.set_weights(weights)
+
+    # test and instantiation from weights
+    if "weights" in tf_inspect.getargspec(layer_cls.__init__):
+        kwargs["weights"] = weights
+        layer = layer_cls(**kwargs)
+
+    # test in functional API
+    x = layers.Input(shape=input_shape[1:], dtype=input_dtype)
+    y = layer(x)
+    if backend.dtype(y) != expected_output_dtype:
+        raise AssertionError(
+            "When testing layer %s, for input %s, found output "
+            "dtype=%s but expected to find %s.\nFull kwargs: %s"
+            % (
+                layer_cls.__name__,
+                x,
+                backend.dtype(y),
+                expected_output_dtype,
+                kwargs,
+            )
+        )
+
+    def assert_shapes_equal(expected, actual):
+        """Asserts that the output shape from the layer matches the actual
+        shape."""
+        if len(expected) != len(actual):
+            raise AssertionError(
+                "When testing layer %s, for input %s, found output_shape="
+                "%s but expected to find %s.\nFull kwargs: %s"
+                % (layer_cls.__name__, x, actual, expected, kwargs)
+            )
+
+        for expected_dim, actual_dim in zip(expected, actual):
+            if expected_dim is not None and expected_dim != actual_dim:
+                raise AssertionError(
+                    "When testing layer %s, for input %s, found output_shape="
+                    "%s but expected to find %s.\nFull kwargs: %s"
+                    % (layer_cls.__name__, x, actual, expected, kwargs)
+                )
+
+    if expected_output_shape is not None:
+        assert_shapes_equal(tf.TensorShape(expected_output_shape), y.shape)
+
+    # check shape inference
+    model = models.Model(x, y)
+    computed_output_shape = tuple(
+        layer.compute_output_shape(tf.TensorShape(input_shape)).as_list()
+    )
+    computed_output_signature = layer.compute_output_signature(
+        tf.TensorSpec(shape=input_shape, dtype=input_dtype)
+    )
+    actual_output = model.predict(input_data)
+    actual_output_shape = actual_output.shape
+    assert_shapes_equal(computed_output_shape, actual_output_shape)
+    assert_shapes_equal(computed_output_signature.shape, actual_output_shape)
+    if computed_output_signature.dtype != actual_output.dtype:
+        raise AssertionError(
+            "When testing layer %s, for input %s, found output_dtype="
+            "%s but expected to find %s.\nFull kwargs: %s"
+            % (
+                layer_cls.__name__,
+                x,
+                actual_output.dtype,
+                computed_output_signature.dtype,
+                kwargs,
+            )
+        )
+    if expected_output is not None:
+        assert_equal(actual_output, expected_output)
+
+    # test serialization, weight setting at model level
+    model_config = model.get_config()
+    recovered_model = models.Model.from_config(model_config, custom_objects)
+    if model.weights:
+        weights = model.get_weights()
+        recovered_model.set_weights(weights)
+        output = recovered_model.predict(input_data)
+        assert_equal(output, actual_output)
+
+    # test training mode (e.g. useful for dropout tests)
+    # Rebuild the model to avoid the graph being reused between predict() and
+    # See b/120160788 for more details. This should be mitigated after 2.0.
+    layer_weights = layer.get_weights()  # Get the layer weights BEFORE training.
+    if validate_training:
+        model = models.Model(x, layer(x))
+        if _thread_local_data.run_eagerly is not None:
+            model.compile(
+                "rmsprop",
+                "mse",
+                weighted_metrics=["acc"],
+                run_eagerly=should_run_eagerly(),
+            )
+        else:
+            model.compile("rmsprop", "mse", weighted_metrics=["acc"])
+        model.train_on_batch(input_data, actual_output)
+
+    # test as first layer in Sequential API
+    layer_config = layer.get_config()
+    layer_config["batch_input_shape"] = input_shape
+    layer = layer.__class__.from_config(layer_config)
+
+    # Test adapt, if data was passed.
+    if adapt_data is not None:
+        layer.adapt(adapt_data)
+
+    model = models.Sequential()
+    model.add(layers.Input(shape=input_shape[1:], dtype=input_dtype))
+    model.add(layer)
+
+    layer.set_weights(layer_weights)
+    actual_output = model.predict(input_data)
+    actual_output_shape = actual_output.shape
+    for expected_dim, actual_dim in zip(computed_output_shape, actual_output_shape):
+        if expected_dim is not None:
+            if expected_dim != actual_dim:
+                raise AssertionError(
+                    "When testing layer %s **after deserialization**, "
+                    "for input %s, found output_shape="
+                    "%s but expected to find inferred shape %s.\n"
+                    "Full kwargs: %s"
+                    % (
+                        layer_cls.__name__,
+                        x,
+                        actual_output_shape,
+                        computed_output_shape,
+                        kwargs,
+                    )
+                )
+    if expected_output is not None:
+        assert_equal(actual_output, expected_output)
+
+    # test serialization, weight setting at model level
+    model_config = model.get_config()
+    recovered_model = models.Sequential.from_config(model_config, custom_objects)
+    if model.weights:
+        weights = model.get_weights()
+        recovered_model.set_weights(weights)
+        output = recovered_model.predict(input_data)
+        assert_equal(output, actual_output)
+
+    # for further checks in the caller function
+    return actual_output
+
+
+_thread_local_data = threading.local()
+_thread_local_data.model_type = None
+_thread_local_data.run_eagerly = None
+_thread_local_data.saved_model_format = None
+_thread_local_data.save_kwargs = None
+
+
+def should_run_eagerly():
+    """Returns whether the models we are testing should be run eagerly."""
+    return _thread_local_data.run_eagerly and tf.executing_eagerly()
diff --git a/tensorflow_addons/utils/types.py b/tensorflow_addons/utils/types.py
index de8da2a5dd..6b8c00e5ea 100644
--- a/tensorflow_addons/utils/types.py
+++ b/tensorflow_addons/utils/types.py
@@ -22,15 +22,22 @@
 
 from packaging.version import Version
 
-# TODO: Remove once https://github.com/tensorflow/tensorflow/issues/44613 is resolved
-if Version(tf.__version__).release >= Version("2.13").release:
-    # New versions of Keras require importing from `keras.src` when
-    # importing internal symbols.
-    from keras.src.engine import keras_tensor
+# Find KerasTensor.
+if Version(tf.__version__).release >= Version("2.16").release:
+    # Determine if loading keras 2 or 3.
+    if (
+        hasattr(tf.keras, "version")
+        and Version(tf.keras.version()).release >= Version("3.0").release
+    ):
+        from keras import KerasTensor
+    else:
+        from tf_keras.src.engine.keras_tensor import KerasTensor
+elif Version(tf.__version__).release >= Version("2.13").release:
+    from keras.src.engine.keras_tensor import KerasTensor
 elif Version(tf.__version__).release >= Version("2.5").release:
-    from keras.engine import keras_tensor
+    from keras.engine.keras_tensor import KerasTensor
 else:
-    from tensorflow.python.keras.engine import keras_tensor
+    from tensorflow.python.keras.engine.keras_tensor import KerasTensor
 
 
 Number = Union[
@@ -68,7 +75,7 @@
     tf.Tensor,
     tf.SparseTensor,
     tf.Variable,
-    keras_tensor.KerasTensor,
+    KerasTensor,
 ]
 FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32, np.float64]
 AcceptableDTypes = Union[tf.DType, np.dtype, type, int, str, None]