From 8062c84fe643566b61b6ec33e2c3efc3c20f7acb Mon Sep 17 00:00:00 2001
From: Serena Ruan <82044803+serena-ruan@users.noreply.github.com>
Date: Tue, 28 Nov 2023 17:40:14 -0800
Subject: [PATCH] Fix for keras 3.0 (#10485)

Signed-off-by: serena-ruan <serena.ruan@ip-10-110-25-32.us-west-2.compute.internal>
Co-authored-by: serena-ruan <serena.ruan@ip-10-110-25-32.us-west-2.compute.internal>
---
 mlflow/ml-package-versions.yml                |  2 +-
 mlflow/tensorflow/__init__.py                 | 78 ++++++++++++++-----
 ...-23-env.yaml => mlflow-128-tf-26-env.yaml} |  7 +-
 tests/tensorflow/test_keras_model_export.py   | 21 +++--
 ...pyfunc_model_works_with_all_input_types.py | 12 ++-
 .../test_load_saved_tensorflow_estimator.py   |  6 +-
 tests/tensorflow/test_mlflow_callback.py      |  8 +-
 tests/tensorflow/test_tensorflow2_autolog.py  | 66 ++++++++--------
 8 files changed, 131 insertions(+), 69 deletions(-)
 rename tests/tensorflow/{mlflow-128-tf-23-env.yaml => mlflow-128-tf-26-env.yaml} (72%)

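Note: throughout this patch, Keras 3 (first bundled with TensorFlow 2.16) is detected by comparing the parsed release tuple rather than full Version objects, so dev and rc builds of 2.16 take the new code paths too. A minimal sketch of the pattern, assuming only that packaging and tensorflow are importable:

    from packaging.version import Version
    import tensorflow

    # Version("2.16.0.dev20231128").release == (2, 16, 0), which compares
    # >= (2, 16); a plain Version comparison would rank the dev build below
    # Version("2.16.0") and wrongly skip the Keras 3 code paths.
    IS_TF_216_OR_NEWER = Version(tensorflow.__version__).release >= (2, 16)
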
diff --git a/mlflow/ml-package-versions.yml b/mlflow/ml-package-versions.yml
index bec9c35aabe99..fa5e5d3b8a64b 100644
--- a/mlflow/ml-package-versions.yml
+++ b/mlflow/ml-package-versions.yml
@@ -86,7 +86,7 @@ tensorflow:
     requirements:
       # Requirements to run tests for keras
       ">= 0.0.0": ["scikit-learn", "pyspark", "pyarrow", "transformers"]
-      "< 2.7.0": ["pandas==1.3.5"]
+      "< 2.7.0": ["pandas>=1.3.5,<2.0"]
       ">= 2.7.0": ["pandas<2.0"]
      # TensorFlow == 2.6.5 is incompatible with SQLAlchemy 2.x due to
       # transitive dependency version conflicts with the `typing-extensions` package
diff --git a/mlflow/tensorflow/__init__.py b/mlflow/tensorflow/__init__.py
index 41048071e185f..ad196d579db33 100644
--- a/mlflow/tensorflow/__init__.py
+++ b/mlflow/tensorflow/__init__.py
@@ -85,6 +85,9 @@
 
 # File name to which custom objects cloudpickle is saved - used during save and load
 _CUSTOM_OBJECTS_SAVE_PATH = "custom_objects.cloudpickle"
+# File name to which custom objects stored in TensorFlow's _GLOBAL_CUSTOM_OBJECTS
+# dictionary are saved - they are automatically detected and used during save and load
+_GLOBAL_CUSTOM_OBJECTS_SAVE_PATH = "global_custom_objects.cloudpickle"
 _KERAS_MODULE_SPEC_PATH = "keras_module.txt"
 _KERAS_SAVE_FORMAT_PATH = "save_format.txt"
 # File name to which keras model is saved
@@ -117,6 +120,18 @@ def get_default_conda_env():
     return _mlflow_conda_env(additional_pip_deps=get_default_pip_requirements())
 
 
+def get_global_custom_objects():
+    """
+    :return: A live reference to the global dictionary of custom objects,
+             or None if it cannot be retrieved.
+    """
+    try:
+        from tensorflow.keras.saving import get_custom_objects
+
+        return get_custom_objects()
+    except Exception:
+        # get_custom_objects is unavailable on older TF versions; callers
+        # treat None as "no global custom objects".
+        return None
+
+
 @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
 def log_model(
     model,
@@ -216,7 +231,7 @@ def log_model(
     )
 
 
-def _save_keras_custom_objects(path, custom_objects):
+def _save_keras_custom_objects(path, custom_objects, file_name):
     """
     Save custom objects dictionary to a cloudpickle file so a model can be easily loaded later.
 
@@ -227,10 +242,11 @@ def _save_keras_custom_objects(path, custom_objects):
                            CloudPickle and restores them automatically when the model is
                            loaded with :py:func:`mlflow.keras.load_model` and
                            :py:func:`mlflow.pyfunc.load_model`.
+    :param file_name: The file name to save the custom objects to.
     """
     import cloudpickle
 
-    custom_objects_path = os.path.join(path, _CUSTOM_OBJECTS_SAVE_PATH)
+    custom_objects_path = os.path.join(path, file_name)
     with open(custom_objects_path, "wb") as out_f:
         cloudpickle.dump(custom_objects, out_f)
 
@@ -390,7 +406,12 @@ def save_model(
         keras_module = importlib.import_module("tensorflow.keras")
         # save custom objects if there are custom objects
         if custom_objects is not None:
-            _save_keras_custom_objects(data_path, custom_objects)
+            _save_keras_custom_objects(data_path, custom_objects, _CUSTOM_OBJECTS_SAVE_PATH)
+        # save custom objects stored within _GLOBAL_CUSTOM_OBJECTS
+        if global_custom_objects := get_global_custom_objects():
+            _save_keras_custom_objects(
+                data_path, global_custom_objects, _GLOBAL_CUSTOM_OBJECTS_SAVE_PATH
+            )
 
         # save keras module spec to path/data/keras_module.txt
         with open(os.path.join(data_path, _KERAS_MODULE_SPEC_PATH), "w") as f:
@@ -407,7 +428,14 @@ def save_model(
         # To maintain prior behavior, when the format is HDF5, we save
         # with the h5 file extension. Otherwise, model_path is a directory
         # where the saved_model.pb will be stored (for SavedModel format)
-        file_extension = ".h5" if save_format == "h5" else ""
+        # Since TensorFlow 2.16.0 (including dev builds), models can only
+        # be saved in the .h5 or .keras format
+        if save_format == "h5":
+            file_extension = ".h5"
+        elif Version(tensorflow.__version__).release >= (2, 16):
+            file_extension = ".keras"
+        else:
+            file_extension = ""
         model_path = os.path.join(path, model_subpath) + file_extension
         if path.startswith("/dbfs/"):
             # The Databricks Filesystem uses a FUSE implementation that does not support
@@ -462,7 +490,7 @@ def save_model(
     # save mlflow_model to path/MLmodel
     mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))
 
-    include_cloudpickle = custom_objects is not None
+    # Match the truthiness check used at save time: an empty global registry
+    # does not produce a cloudpickle file and must not add the requirement.
+    include_cloudpickle = custom_objects is not None or bool(get_global_custom_objects())
     if conda_env is None:
         if pip_requirements is None:
             default_reqs = get_default_pip_requirements(include_cloudpickle)
@@ -495,31 +523,46 @@ def save_model(
     _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))
 
 
+def _load_custom_objects(path, file_name):
+    # Returns the unpickled custom objects dict, or None if the file is absent.
+    custom_objects_path = os.path.join(path, file_name)
+    if os.path.isdir(path) and os.path.isfile(custom_objects_path):
+        import cloudpickle
+
+        with open(custom_objects_path, "rb") as f:
+            return cloudpickle.load(f)
+
+
 def _load_keras_model(model_path, keras_module, save_format, **kwargs):
     keras_models = importlib.import_module(keras_module.__name__ + ".models")
     custom_objects = kwargs.pop("custom_objects", {})
-    custom_objects_path = None
+    if saved_custom_objects := _load_custom_objects(model_path, _CUSTOM_OBJECTS_SAVE_PATH):
+        saved_custom_objects.update(custom_objects)
+        custom_objects = saved_custom_objects
+
+    if global_custom_objects := _load_custom_objects(model_path, _GLOBAL_CUSTOM_OBJECTS_SAVE_PATH):
+        global_custom_objects.update(custom_objects)
+        custom_objects = global_custom_objects
+
     if os.path.isdir(model_path):
-        if os.path.isfile(os.path.join(model_path, _CUSTOM_OBJECTS_SAVE_PATH)):
-            custom_objects_path = os.path.join(model_path, _CUSTOM_OBJECTS_SAVE_PATH)
         model_path = os.path.join(model_path, _MODEL_SAVE_PATH)
-    if custom_objects_path is not None:
-        import cloudpickle
-
-        with open(custom_objects_path, "rb") as in_f:
-            pickled_custom_objects = cloudpickle.load(in_f)
-            pickled_custom_objects.update(custom_objects)
-            custom_objects = pickled_custom_objects
 
     # If the save_format is HDF5, then we save with h5 file
     # extension to align with prior behavior of mlflow logging
     if save_format == "h5":
-        model_path = model_path + ".h5"
+        model_path += ".h5"
+    # Since TF 2.16.0, models can only be saved in the .h5 or .keras format.
+    # For backwards compatibility, models saved by older TF versions keep
+    # their suffix-less path, so only add ".keras" when that file exists.
+    elif os.path.exists(model_path + ".keras"):
+        model_path += ".keras"
 
     # keras in tensorflow used to have a '-tf' suffix in the version:
     # https://github.com/tensorflow/tensorflow/blob/v2.2.1/tensorflow/python/keras/__init__.py#L36
     unsuffixed_version = re.sub(r"-tf$", "", _get_keras_version(keras_module))
-    if save_format == "h5" and Version(unsuffixed_version) >= Version("2.2.3"):
+    if save_format == "h5" and (2, 2, 3) <= Version(unsuffixed_version).release < (2, 16):
         # NOTE: Keras 2.2.3 does not work with unicode paths in python2. Pass in h5py.File instead
         # of string to avoid issues.
         import h5py
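The merge order above gives user-supplied objects the highest precedence. Equivalently, as a sketch with hypothetical names:

    # user_objects: the custom_objects kwarg passed to load_model
    # saved_objects: custom_objects.cloudpickle, pickled at save time
    # global_objects: global_custom_objects.cloudpickle, the registry snapshot
    merged = dict(global_objects)   # lowest precedence
    merged.update(saved_objects)    # overrides the global registry
    merged.update(user_objects)     # always wins
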
@@ -709,7 +752,6 @@ def _load_pyfunc(path):
         should_compile = save_format == "tf"
         K = importlib.import_module(keras_module.__name__ + ".backend")
         if K.backend() == "tensorflow":
-            K.set_learning_phase(0)
             m = _load_keras_model(
                 path, keras_module=keras_module, save_format=save_format, compile=should_compile
             )
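For reference, a minimal sketch of how the registry read by get_global_custom_objects is populated, assuming a TF version new enough to expose tensorflow.keras.saving.get_custom_objects (the same import the helper relies on):

    import tensorflow as tf
    from tensorflow.keras.saving import get_custom_objects

    @tf.keras.utils.register_keras_serializable(package="demo")
    def scaled_tanh(x):
        return 2.0 * tf.tanh(x)

    # Registered under "<package>><name>"; save_model pickles this dict to
    # global_custom_objects.cloudpickle so load_model can restore it.
    assert "demo>scaled_tanh" in get_custom_objects()
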
diff --git a/tests/tensorflow/mlflow-128-tf-23-env.yaml b/tests/tensorflow/mlflow-128-tf-26-env.yaml
similarity index 72%
rename from tests/tensorflow/mlflow-128-tf-23-env.yaml
rename to tests/tensorflow/mlflow-128-tf-26-env.yaml
index 4d788e58faecf..2fae2d63d8af4 100644
--- a/tests/tensorflow/mlflow-128-tf-23-env.yaml
+++ b/tests/tensorflow/mlflow-128-tf-26-env.yaml
@@ -5,10 +5,9 @@ dependencies:
   - pip<=21.2.4
   - pip:
       - mlflow==1.28.0
-      - h5py<3.0.0
-      - tensorflow==2.3.0
+      - tensorflow==2.6.5
       # pin pandas version to avoid pickling error
       # AttributeError: Can't get attribute '_unpickle_block'
-      - pandas==1.3.5
+      - pandas>=1.3.5,<2.0.0
       - protobuf<4.0.0
-name: mlflow-128-tf-23-env
+name: mlflow-128-tf-26-env
diff --git a/tests/tensorflow/test_keras_model_export.py b/tests/tensorflow/test_keras_model_export.py
index 5e32051d837b5..5776e26b768a1 100644
--- a/tests/tensorflow/test_keras_model_export.py
+++ b/tests/tensorflow/test_keras_model_export.py
@@ -309,7 +309,7 @@ def __call__(self):
         model_path, keras_model_kwargs={"custom_objects": correct_custom_objects}
     )
     assert model_loaded is not None
-    if Version(tf.__version__) <= Version("2.11.0"):
+    if Version(tf.__version__) <= Version("2.11.0") or Version(tf.__version__).release >= (2, 16):
         with pytest.raises(TypeError, match=r".+"):
             mlflow.tensorflow.load_model(model_path)
     else:
@@ -318,6 +318,8 @@ def __call__(self):
         # validated eagerly. This prevents a TypeError from being thrown as in the above
         # expectation catching validation block. The change in logic now permits loading and
         # will not raise an Exception, as validated below.
+        # TF 2.16.0 changes the logic again: if the custom object was neither
+        # saved with the model nor supplied in the load_model call, loading
+        # fails, which is why TF >= 2.16 takes the TypeError branch above.
         incorrect_loaded = mlflow.tensorflow.load_model(model_path)
         assert incorrect_loaded is not None
 
@@ -598,9 +600,14 @@ def test_save_and_load_model_with_tf_save_format(tf_keras_model, model_path, dat
     assert not os.path.exists(
         os.path.join(model_path, "data", "model.h5")
     ), "TF model was saved with HDF5 format; expected SavedModel"
-    assert os.path.isdir(
-        os.path.join(model_path, "data", "model")
-    ), "Expected directory containing saved_model.pb"
+    if Version(tf.__version__).release < (2, 16):
+        assert os.path.isdir(
+            os.path.join(model_path, "data", "model")
+        ), "Expected directory containing saved_model.pb"
+    else:
+        assert os.path.exists(
+            os.path.join(model_path, "data", "model.keras")
+        ), "Expected model saved as model.keras"
 
     model_loaded = mlflow.tensorflow.load_model(model_path)
     np.testing.assert_allclose(model_loaded.predict(data[0]), tf_keras_model.predict(data[0]))
@@ -697,7 +704,7 @@ def test_virtualenv_subfield_points_to_correct_path(model, model_path):
 
 def save_or_log_keras_model_by_mlflow128(tmp_path, task_type, save_as_type, save_path=None):
     tf_tests_dir = os.path.dirname(__file__)
-    conda_env = get_or_create_conda_env(os.path.join(tf_tests_dir, "mlflow-128-tf-23-env.yaml"))
+    conda_env = get_or_create_conda_env(os.path.join(tf_tests_dir, "mlflow-128-tf-26-env.yaml"))
     output_data_file_path = os.path.join(tmp_path, "output_data.pkl")
     tracking_uri = mlflow.get_tracking_uri()
     exec_py_path = os.path.join(tf_tests_dir, "save_keras_model.py")
@@ -721,6 +728,10 @@ def save_or_log_keras_model_by_mlflow128(tmp_path, task_type, save_as_type, save
         )
 
 
+@pytest.mark.skipif(
+    Version(tf.__version__).release >= (2, 16),
+    reason="The file save format is incompatible with tf >= 2.16.0",
+)
 def test_load_and_predict_keras_model_saved_by_mlflow128(tmp_path, monkeypatch):
     mlflow.set_tracking_uri(tmp_path.joinpath("mlruns").as_uri())
     monkeypatch.chdir(tmp_path)
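Taken together, the save-format assertions above imply the following on-disk layout for the flavor's data directory (a sketch; exactly one of the three model entries exists for a given save):

    <model_path>/
        MLmodel
        data/
            keras_module.txt
            save_format.txt
            model.h5        # save_format == "h5"
            model/          # SavedModel directory, TF < 2.16
            model.keras     # TF >= 2.16 (Keras 3)
            custom_objects.cloudpickle         # only if custom_objects was passed
            global_custom_objects.cloudpickle  # only if the global registry is non-empty
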
diff --git a/tests/tensorflow/test_keras_pyfunc_model_works_with_all_input_types.py b/tests/tensorflow/test_keras_pyfunc_model_works_with_all_input_types.py
index 73f6ae57e8695..23c48407fa0d1 100644
--- a/tests/tensorflow/test_keras_pyfunc_model_works_with_all_input_types.py
+++ b/tests/tensorflow/test_keras_pyfunc_model_works_with_all_input_types.py
@@ -10,6 +10,7 @@
 from tensorflow.keras.layers import Concatenate, Dense, Input, Lambda
 from tensorflow.keras.models import Model, Sequential
 from tensorflow.keras.optimizers import SGD
+from tensorflow.keras.utils import register_keras_serializable
 
 import mlflow
 import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
@@ -99,6 +100,10 @@ def single_multidim_tensor_input_model(data):
     x, y = data
     model = Sequential()
 
+    # This decorator injects the decorated class or function into the Keras custom
+    # object dictionary, so that it can be serialized and deserialized without
+    # needing an entry in the user-provided custom object dict.
+    @register_keras_serializable(name="f1")
     def f1(z):
         from tensorflow.keras import backend as K
 
@@ -130,13 +135,14 @@ def multi_multidim_tensor_input_model(data):
     input_a = Input(shape=(2, 3), name="a")
     input_b = Input(shape=(2, 5), name="b")
 
-    def f1(z):
+    @register_keras_serializable(name="f2")
+    def f2(z):
         from tensorflow.keras import backend as K
 
         return K.mean(z, axis=2)
 
-    input_a_sum = Lambda(f1)(input_a)
-    input_b_sum = Lambda(f1)(input_b)
+    input_a_sum = Lambda(f2)(input_a)
+    input_b_sum = Lambda(f2)(input_b)
 
     output = Dense(1)(Dense(3, input_dim=4)(Concatenate()([input_a_sum, input_b_sum])))
     model = Model(inputs=[input_a, input_b], outputs=output)
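(f1 is renamed to f2 here, presumably so its registered name does not collide with the f1 already registered in single_multidim_tensor_input_model.) Illustrative only: with f2 registered, the Lambda model round-trips through MLflow without a custom_objects argument at load time, because save_model now pickles the global registry:

    with mlflow.start_run():
        model_info = mlflow.tensorflow.log_model(model, "model")
    loaded = mlflow.tensorflow.load_model(model_info.model_uri)
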
diff --git a/tests/tensorflow/test_load_saved_tensorflow_estimator.py b/tests/tensorflow/test_load_saved_tensorflow_estimator.py
index f7df2ad3d41a1..925c320469a2f 100644
--- a/tests/tensorflow/test_load_saved_tensorflow_estimator.py
+++ b/tests/tensorflow/test_load_saved_tensorflow_estimator.py
@@ -1,7 +1,6 @@
 import collections
 import json
 import os
-import pickle
 
 import iris_data_utils
 import numpy as np
@@ -50,7 +49,7 @@ def model_path(tmp_path):
 
 def save_or_log_tf_model_by_mlflow128(tmp_path, model_type, task_type, save_path=None):
     tf_tests_dir = os.path.dirname(__file__)
-    conda_env = get_or_create_conda_env(os.path.join(tf_tests_dir, "mlflow-128-tf-23-env.yaml"))
+    conda_env = get_or_create_conda_env(os.path.join(tf_tests_dir, "mlflow-128-tf-26-env.yaml"))
     output_data_file_path = os.path.join(tmp_path, "output_data.pkl")
     tracking_uri = mlflow.get_tracking_uri()
     exec_py_path = os.path.join(tf_tests_dir, "save_tf_estimator_model.py")
@@ -64,8 +63,7 @@ def save_or_log_tf_model_by_mlflow128(tmp_path, model_type, task_type, save_path
         f"--task_type {task_type} "
         f"--save_path {save_path if save_path else 'none'}",
     )
-    with open(output_data_file_path, "rb") as f:
-        return ModelDataInfo(*pickle.load(f))
+    return ModelDataInfo(*pd.read_pickle(output_data_file_path))
 
 
 def test_load_model_from_remote_uri_succeeds(tmp_path, model_path, mock_s3_bucket, monkeypatch):
diff --git a/tests/tensorflow/test_mlflow_callback.py b/tests/tensorflow/test_mlflow_callback.py
index def03ed338bc5..b7a1830cb4daa 100644
--- a/tests/tensorflow/test_mlflow_callback.py
+++ b/tests/tensorflow/test_mlflow_callback.py
@@ -38,7 +38,9 @@ def test_tf_mlflow_callback(log_every_epoch, log_every_n_steps):
             label,
             validation_data=(data, label),
             batch_size=4,
-            epochs=2,
+            # Increase the number of epochs so that
+            # logs are flushed correctly
+            epochs=5,
             callbacks=[mlflow_callback],
         )
 
@@ -49,5 +51,5 @@ def test_tf_mlflow_callback(log_every_epoch, log_every_n_steps):
 
     assert "loss" in run_metrics
     assert "sparse_categorical_accuracy" in run_metrics
-    assert model_info["optimizer_name"] == "Adam"
-    assert model_info["optimizer_learning_rate"] == "0.001"
+    assert model_info["optimizer_name"].lower() == "adam"
+    np.testing.assert_almost_equal(float(model_info["optimizer_learning_rate"]), 0.001)
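The relaxed assertions reflect a Keras 3 change: optimizer names are reported in lower case, and the logged learning rate may carry float noise. An illustrative check:

    import tensorflow as tf

    opt = tf.keras.optimizers.Adam()
    # "Adam" under Keras 2, "adam" under Keras 3 (TF >= 2.16)
    print(opt.get_config()["name"])
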
diff --git a/tests/tensorflow/test_tensorflow2_autolog.py b/tests/tensorflow/test_tensorflow2_autolog.py
index 258f966813500..df728735beb2a 100644
--- a/tests/tensorflow/test_tensorflow2_autolog.py
+++ b/tests/tensorflow/test_tensorflow2_autolog.py
@@ -73,20 +73,16 @@ def _generate_features(pos):
 
 
 def _create_model_for_dict_mapping():
-    model = tf.keras.Sequential()
-    model.add(
-        layers.DenseFeatures(
-            [
-                tf.feature_column.numeric_column("a"),
-                tf.feature_column.numeric_column("b"),
-                tf.feature_column.numeric_column("c"),
-                tf.feature_column.numeric_column("d"),
-            ]
-        )
-    )
-    model.add(layers.Dense(16, activation="relu", input_shape=(4,)))
-    model.add(layers.Dense(3, activation="softmax"))
-
+    inputs = {
+        "a": tf.keras.Input(shape=(1,), name="a"),
+        "b": tf.keras.Input(shape=(1,), name="b"),
+        "c": tf.keras.Input(shape=(1,), name="c"),
+        "d": tf.keras.Input(shape=(1,), name="d"),
+    }
+    concatenated = layers.Concatenate()(list(inputs.values()))
+    x = layers.Dense(16, activation="relu", input_shape=(4,))(concatenated)
+    outputs = layers.Dense(3, activation="softmax")(x)
+    model = tf.keras.Model(inputs=inputs, outputs=outputs)
     model.compile(
         optimizer=tf.keras.optimizers.Adam(), loss="categorical_crossentropy", metrics=["accuracy"]
     )
@@ -403,7 +399,7 @@ def test_tf_keras_autolog_logs_expected_data(tf_keras_random_data_run):
     assert "validation_data" not in data.params
     # Testing optimizer parameters are logged
     assert "opt_name" in data.params
-    assert data.params["opt_name"] == "Adam"
+    assert data.params["opt_name"].lower() == "adam"
     assert "opt_learning_rate" in data.params
     assert "opt_beta_1" in data.params
     assert "opt_beta_2" in data.params
@@ -576,7 +572,18 @@ def test_tf_keras_autolog_implicit_batch_size_works_multi_input(generate_data, b
     Version(tf.__version__) < Version("2.1.4"),
     reason="Does not support passing of generator classes as `x` in `fit`",
 )
-@pytest.mark.parametrize("generator", [__generator, __GeneratorClass])
+@pytest.mark.parametrize(
+    "generator",
+    [
+        __generator,
+        pytest.param(
+            __GeneratorClass,
+            marks=pytest.mark.skipif(
+                Version(tf.__version__).release >= (2, 16),
+                reason="TF >= 2.16 does not support passing generator classes as `x` in `fit`",
+            ),
+        ),
+    ],
+)
 @pytest.mark.parametrize("batch_size", [2, 3, 6])
 def test_tf_keras_autolog_implicit_batch_size_for_generator_dataset_without_side_effects(
     generator,
@@ -617,11 +624,7 @@ def test_tf_keras_autolog_succeeds_for_tf_datasets_lacking_batch_size_info():
     assert not hasattr(train_ds, "_batch_size")
 
     model = tf.keras.Sequential()
-    model.add(
-        tf.keras.Input(
-            100,
-        )
-    )
+    model.add(tf.keras.Input((100,)))
     model.add(tf.keras.layers.Dense(256, activation="relu"))
     model.add(tf.keras.layers.Dropout(rate=0.4))
     model.add(tf.keras.layers.Dense(10, activation="sigmoid"))
@@ -1007,14 +1010,16 @@ def get_text_vec_model(train_samples):
     # Taken from: https://github.com/mlflow/mlflow/issues/3910
 
     # pylint: disable=no-name-in-module
-    from tensorflow.keras.layers.experimental.preprocessing import TextVectorization
+    try:
+        from tensorflow.keras.layers.experimental.preprocessing import TextVectorization
+    except ModuleNotFoundError:
+        from tensorflow.keras.layers import TextVectorization
 
     VOCAB_SIZE = 10
     SEQUENCE_LENGTH = 16
     EMBEDDING_DIM = 16
 
     vectorizer_layer = TextVectorization(
-        input_shape=(1,),
         max_tokens=VOCAB_SIZE,
         output_mode="int",
         output_sequence_length=SEQUENCE_LENGTH,
@@ -1028,14 +1033,13 @@ def get_text_vec_model(train_samples):
                 EMBEDDING_DIM,
                 name="embedding",
                 mask_zero=True,
-                input_shape=(1,),
             ),
             tf.keras.layers.GlobalAveragePooling1D(),
             tf.keras.layers.Dense(16, activation="relu"),
             tf.keras.layers.Dense(1, activation="tanh"),
         ]
     )
-    model.compile(optimizer="adam", loss="mse", metrics="mae")
+    model.compile(optimizer="adam", loss="mse", metrics=["mae"])
     return model
 
 
@@ -1053,14 +1057,10 @@ def test_autolog_text_vec_model(tmp_path):
     """
     mlflow.tensorflow.autolog()
 
-    train_samples = np.array(["this is an example", "another example"])
+    train_samples = np.array(["this is an example", "another example"], dtype=object)
     train_labels = np.array([0.4, 0.2])
     model = get_text_vec_model(train_samples)
 
-    # Saving in the H5 format should fail
-    with pytest.raises(NotImplementedError, match="is not supported in h5"):
-        model.save(str(tmp_path.joinpath("model.h5")), save_format="h5")
-
     with mlflow.start_run() as run:
         model.fit(train_samples, train_labels, epochs=1)
 
@@ -1144,7 +1144,11 @@ def test_fluent_autolog_with_tf_keras_preserves_v2_model_reference():
     mlflow.autolog()
 
     import tensorflow.keras
-    from keras.api._v2.keras import Model as ModelV2
+
+    if Version(tf.__version__).release < (2, 16):
+        from keras.api._v2.keras import Model as ModelV2
+    else:
+        from keras import Model as ModelV2
 
     assert tensorflow.keras.Model is ModelV2