Fix for keras 3.0 (mlflow#10485)
Signed-off-by: serena-ruan <[email protected]>
Co-authored-by: serena-ruan <[email protected]>
serena-ruan authored Nov 29, 2023
1 parent c0951c8 commit 8062c84
Showing 8 changed files with 131 additions and 69 deletions.
2 changes: 1 addition & 1 deletion mlflow/ml-package-versions.yml
@@ -86,7 +86,7 @@ tensorflow:
   requirements:
     # Requirements to run tests for keras
     ">= 0.0.0": ["scikit-learn", "pyspark", "pyarrow", "transformers"]
-    "< 2.7.0": ["pandas==1.3.5"]
+    "< 2.7.0": ["pandas>=1.3.5,<2.0"]
     ">= 2.7.0": ["pandas<2.0"]
     # TensorFlow == 2.6.5 is incompatible with SQLAlchemy 2.x due to
     # transitive dependency version conflicts with the `typing-extensions` package
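
Each key in this requirements block is a pip-style version specifier matched against the installed package version, and every matching row applies. A hypothetical resolver sketch (illustrative only; this is not MLflow's actual test-harness code):

    from packaging.specifiers import SpecifierSet
    from packaging.version import Version

    requirements = {
        ">= 0.0.0": ["scikit-learn", "pyspark", "pyarrow", "transformers"],
        "< 2.7.0": ["pandas>=1.3.5,<2.0"],
        ">= 2.7.0": ["pandas<2.0"],
    }

    def resolve(version):
        # Collect requirements from every specifier the version satisfies.
        v = Version(version)
        return [r for spec, reqs in requirements.items() if v in SpecifierSet(spec) for r in reqs]

    print(resolve("2.6.5"))
    # ['scikit-learn', 'pyspark', 'pyarrow', 'transformers', 'pandas>=1.3.5,<2.0']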
78 changes: 60 additions & 18 deletions mlflow/tensorflow/__init__.py
@@ -85,6 +85,9 @@

 # File name to which custom objects cloudpickle is saved - used during save and load
 _CUSTOM_OBJECTS_SAVE_PATH = "custom_objects.cloudpickle"
+# File name to which custom objects stored in tensorflow's _GLOBAL_CUSTOM_OBJECTS
+# registry are saved - it is automatically detected and used during save and load
+_GLOBAL_CUSTOM_OBJECTS_SAVE_PATH = "global_custom_objects.cloudpickle"
 _KERAS_MODULE_SPEC_PATH = "keras_module.txt"
 _KERAS_SAVE_FORMAT_PATH = "save_format.txt"
 # File name to which keras model is saved
@@ -117,6 +120,18 @@ def get_default_conda_env():
     return _mlflow_conda_env(additional_pip_deps=get_default_pip_requirements())


+def get_global_custom_objects():
+    """
+    :return: A live reference to the global dictionary of custom objects.
+    """
+    try:
+        from tensorflow.keras.saving import get_custom_objects
+
+        return get_custom_objects()
+    except Exception:
+        pass
+
+
 @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
 def log_model(
     model,
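
For context on the new helper: `register_keras_serializable` (the same decorator the updated tests below use) adds entries to Keras's global custom-object registry, and `get_custom_objects()` returns a live reference to that dict. A minimal sketch, assuming the current Keras key format of "package>name"; the registered function is illustrative:

    from tensorflow.keras.saving import get_custom_objects
    from tensorflow.keras.utils import register_keras_serializable

    @register_keras_serializable(package="demo")
    def double(x):
        return x * 2

    # The registry now holds the function under "demo>double", so the new
    # get_global_custom_objects() helper would pick it up at save time.
    assert "demo>double" in get_custom_objects()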
@@ -216,7 +231,7 @@ def log_model(
     )


-def _save_keras_custom_objects(path, custom_objects):
+def _save_keras_custom_objects(path, custom_objects, file_name):
     """
     Save custom objects dictionary to a cloudpickle file so a model can be easily loaded later.
@@ -227,10 +242,11 @@ def _save_keras_custom_objects(path, custom_objects):
                            CloudPickle and restores them automatically when the model is
                            loaded with :py:func:`mlflow.keras.load_model` and
                            :py:func:`mlflow.pyfunc.load_model`.
+    :param file_name: The file name to save the custom objects to.
     """
     import cloudpickle

-    custom_objects_path = os.path.join(path, _CUSTOM_OBJECTS_SAVE_PATH)
+    custom_objects_path = os.path.join(path, file_name)
     with open(custom_objects_path, "wb") as out_f:
         cloudpickle.dump(custom_objects, out_f)
@@ -390,7 +406,12 @@ def save_model(
         keras_module = importlib.import_module("tensorflow.keras")
         # save custom objects if there are custom objects
         if custom_objects is not None:
-            _save_keras_custom_objects(data_path, custom_objects)
+            _save_keras_custom_objects(data_path, custom_objects, _CUSTOM_OBJECTS_SAVE_PATH)
+        # save custom objects stored within _GLOBAL_CUSTOM_OBJECTS
+        if global_custom_objects := get_global_custom_objects():
+            _save_keras_custom_objects(
+                data_path, global_custom_objects, _GLOBAL_CUSTOM_OBJECTS_SAVE_PATH
+            )

         # save keras module spec to path/data/keras_module.txt
         with open(os.path.join(data_path, _KERAS_MODULE_SPEC_PATH), "w") as f:
@@ -407,7 +428,14 @@ def save_model(
         # To maintain prior behavior, when the format is HDF5, we save
         # with the h5 file extension. Otherwise, model_path is a directory
         # where the saved_model.pb will be stored (for SavedModel format)
-        file_extension = ".h5" if save_format == "h5" else ""
+        # TensorFlow >= 2.16.0 (including dev versions) only supports
+        # saving models in .h5 or .keras format
+        if save_format == "h5":
+            file_extension = ".h5"
+        elif Version(tensorflow.__version__).release >= (2, 16):
+            file_extension = ".keras"
+        else:
+            file_extension = ""
         model_path = os.path.join(path, model_subpath) + file_extension
         if path.startswith("/dbfs/"):
             # The Databricks Filesystem uses a FUSE implementation that does not support
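
The net effect on the saved artifact path (a sketch; assuming model_subpath resolves to "data/model", as the updated tests below expect):

    # save_format == "h5"              -> <path>/data/model.h5     (HDF5, any TF version)
    # save_format == "tf", TF >= 2.16  -> <path>/data/model.keras  (native Keras format)
    # save_format == "tf", TF <  2.16  -> <path>/data/model        (SavedModel directory)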
@@ -462,7 +490,7 @@ def save_model(
     # save mlflow_model to path/MLmodel
     mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))

-    include_cloudpickle = custom_objects is not None
+    include_cloudpickle = custom_objects is not None or get_global_custom_objects() is not None
     if conda_env is None:
         if pip_requirements is None:
             default_reqs = get_default_pip_requirements(include_cloudpickle)
@@ -495,31 +523,46 @@ def save_model(
     _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))


+def _load_custom_objects(path, file_name):
+    custom_objects_path = None
+    if os.path.isdir(path):
+        if os.path.isfile(os.path.join(path, file_name)):
+            custom_objects_path = os.path.join(path, file_name)
+    if custom_objects_path is not None:
+        import cloudpickle
+
+        with open(custom_objects_path, "rb") as f:
+            return cloudpickle.load(f)
+
+
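A minimal round trip through the two helpers above (the temporary directory and the custom object are illustrative):

    import tempfile

    def scaled_tanh(x):  # stand-in custom object
        return x

    tmp_dir = tempfile.mkdtemp()
    _save_keras_custom_objects(tmp_dir, {"scaled_tanh": scaled_tanh}, _CUSTOM_OBJECTS_SAVE_PATH)
    restored = _load_custom_objects(tmp_dir, _CUSTOM_OBJECTS_SAVE_PATH)
    assert "scaled_tanh" in restored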
 def _load_keras_model(model_path, keras_module, save_format, **kwargs):
     keras_models = importlib.import_module(keras_module.__name__ + ".models")
     custom_objects = kwargs.pop("custom_objects", {})
-    custom_objects_path = None
+    if saved_custom_objects := _load_custom_objects(model_path, _CUSTOM_OBJECTS_SAVE_PATH):
+        saved_custom_objects.update(custom_objects)
+        custom_objects = saved_custom_objects
+
+    if global_custom_objects := _load_custom_objects(model_path, _GLOBAL_CUSTOM_OBJECTS_SAVE_PATH):
+        global_custom_objects.update(custom_objects)
+        custom_objects = global_custom_objects
+
     if os.path.isdir(model_path):
-        if os.path.isfile(os.path.join(model_path, _CUSTOM_OBJECTS_SAVE_PATH)):
-            custom_objects_path = os.path.join(model_path, _CUSTOM_OBJECTS_SAVE_PATH)
         model_path = os.path.join(model_path, _MODEL_SAVE_PATH)
-    if custom_objects_path is not None:
-        import cloudpickle
-
-        with open(custom_objects_path, "rb") as in_f:
-            pickled_custom_objects = cloudpickle.load(in_f)
-            pickled_custom_objects.update(custom_objects)
-            custom_objects = pickled_custom_objects

     # If the save_format is HDF5, then we save with h5 file
     # extension to align with prior behavior of mlflow logging
     if save_format == "h5":
-        model_path = model_path + ".h5"
+        model_path += ".h5"
+    # Since TF 2.16.0, Keras only supports saving models in .h5 or .keras
+    # format; for backwards compatibility, models saved by older versions
+    # of TF keep their suffix-less path.
+    elif os.path.exists(model_path + ".keras"):
+        model_path += ".keras"

     # keras in tensorflow used to have a '-tf' suffix in the version:
     # https://github.com/tensorflow/tensorflow/blob/v2.2.1/tensorflow/python/keras/__init__.py#L36
     unsuffixed_version = re.sub(r"-tf$", "", _get_keras_version(keras_module))
-    if save_format == "h5" and Version(unsuffixed_version) >= Version("2.2.3"):
+    if save_format == "h5" and (2, 2, 3) <= Version(unsuffixed_version).release < (2, 16):
         # NOTE: Keras 2.2.3 does not work with unicode paths in python2. Pass in h5py.File instead
         # of string to avoid issues.
         import h5py
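
How the rewritten tuple gate evaluates (packaging's Version.release is the real API; the sample versions are illustrative):

    import re
    from packaging.version import Version

    for raw in ["2.2.4-tf", "2.15.0", "2.16.1"]:
        release = Version(re.sub(r"-tf$", "", raw)).release
        print(raw, release, (2, 2, 3) <= release < (2, 16))
    # 2.2.4-tf (2, 2, 4) True   -> h5py.File path used
    # 2.15.0   (2, 15, 0) True  -> h5py.File path used
    # 2.16.1   (2, 16, 1) False -> plain path passed to keras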
@@ -709,7 +752,6 @@ def _load_pyfunc(path):
     should_compile = save_format == "tf"
     K = importlib.import_module(keras_module.__name__ + ".backend")
     if K.backend() == "tensorflow":
-        K.set_learning_phase(0)
         m = _load_keras_model(
             path, keras_module=keras_module, save_format=save_format, compile=should_compile
         )
tests/tensorflow/{mlflow-128-tf-23-env.yaml → mlflow-128-tf-26-env.yaml}
@@ -5,10 +5,9 @@ dependencies:
   - pip<=21.2.4
   - pip:
     - mlflow==1.28.0
-    - h5py<3.0.0
-    - tensorflow==2.3.0
+    - tensorflow==2.6.5
     # pin pandas version to avoid pickling error
     # AttributeError: Can't get attribute '_unpickle_block'
-    - pandas==1.3.5
+    - pandas>=1.3.5,<2.0.0
     - protobuf<4.0.0
-name: mlflow-128-tf-23-env
+name: mlflow-128-tf-26-env
21 changes: 16 additions & 5 deletions tests/tensorflow/test_keras_model_export.py
@@ -309,7 +309,7 @@ def __call__(self):
         model_path, keras_model_kwargs={"custom_objects": correct_custom_objects}
     )
     assert model_loaded is not None
-    if Version(tf.__version__) <= Version("2.11.0"):
+    if Version(tf.__version__) <= Version("2.11.0") or Version(tf.__version__).release >= (2, 16):
         with pytest.raises(TypeError, match=r".+"):
             mlflow.tensorflow.load_model(model_path)
     else:
@@ -318,6 +318,8 @@ def __call__(self):
         # validated eagerly. This prevents a TypeError from being thrown as in the above
         # expectation catching validation block. The change in logic now permits loading and
         # will not raise an Exception, as validated below.
+        # TF 2.16.0 updates the logic such that if the custom object is not saved with the
+        # model or supplied in the load_model call, the model will not be loaded.
         incorrect_loaded = mlflow.tensorflow.load_model(model_path)
         assert incorrect_loaded is not None

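In summary, the load behavior these assertions exercise for a model that depends on an unregistered custom object:

    # TF <= 2.11        : load_model raises TypeError (graph validated eagerly)
    # 2.11 < TF < 2.16  : load_model succeeds; validation is deferred
    # TF >= 2.16        : load_model raises again unless the custom object was
    #                     saved with the model or passed at load time via
    #                     keras_model_kwargs={"custom_objects": ...}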
@@ -598,9 +600,14 @@ def test_save_and_load_model_with_tf_save_format(tf_keras_model, model_path, data):
     assert not os.path.exists(
         os.path.join(model_path, "data", "model.h5")
     ), "TF model was saved with HDF5 format; expected SavedModel"
-    assert os.path.isdir(
-        os.path.join(model_path, "data", "model")
-    ), "Expected directory containing saved_model.pb"
+    if Version(tf.__version__).release < (2, 16):
+        assert os.path.isdir(
+            os.path.join(model_path, "data", "model")
+        ), "Expected directory containing saved_model.pb"
+    else:
+        assert os.path.exists(
+            os.path.join(model_path, "data", "model.keras")
+        ), "Expected model saved as model.keras"

     model_loaded = mlflow.tensorflow.load_model(model_path)
     np.testing.assert_allclose(model_loaded.predict(data[0]), tf_keras_model.predict(data[0]))
@@ -697,7 +704,7 @@ def test_virtualenv_subfield_points_to_correct_path(model, model_path):

 def save_or_log_keras_model_by_mlflow128(tmp_path, task_type, save_as_type, save_path=None):
     tf_tests_dir = os.path.dirname(__file__)
-    conda_env = get_or_create_conda_env(os.path.join(tf_tests_dir, "mlflow-128-tf-23-env.yaml"))
+    conda_env = get_or_create_conda_env(os.path.join(tf_tests_dir, "mlflow-128-tf-26-env.yaml"))
     output_data_file_path = os.path.join(tmp_path, "output_data.pkl")
     tracking_uri = mlflow.get_tracking_uri()
     exec_py_path = os.path.join(tf_tests_dir, "save_keras_model.py")
@@ -721,6 +728,10 @@ def save_or_log_keras_model_by_mlflow128(tmp_path, task_type, save_as_type, save_path=None):
     )


+@pytest.mark.skipif(
+    Version(tf.__version__).release >= (2, 16),
+    reason="File save format incompatible for tf >= 2.16.0",
+)
 def test_load_and_predict_keras_model_saved_by_mlflow128(tmp_path, monkeypatch):
     mlflow.set_tracking_uri(tmp_path.joinpath("mlruns").as_uri())
     monkeypatch.chdir(tmp_path)
@@ -10,6 +10,7 @@
 from tensorflow.keras.layers import Concatenate, Dense, Input, Lambda
 from tensorflow.keras.models import Model, Sequential
 from tensorflow.keras.optimizers import SGD
+from tensorflow.keras.utils import register_keras_serializable

 import mlflow
 import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
@@ -99,6 +100,10 @@ def single_multidim_tensor_input_model(data):
     x, y = data
     model = Sequential()

+    # This decorator injects the decorated class or function into the Keras custom
+    # object dictionary, so that it can be serialized and deserialized without
+    # needing an entry in the user-provided custom object dict.
+    @register_keras_serializable(name="f1")
     def f1(z):
         from tensorflow.keras import backend as K

@@ -130,13 +135,14 @@ def multi_multidim_tensor_input_model(data):
     input_a = Input(shape=(2, 3), name="a")
     input_b = Input(shape=(2, 5), name="b")

-    def f1(z):
+    @register_keras_serializable(name="f2")
+    def f2(z):
         from tensorflow.keras import backend as K

         return K.mean(z, axis=2)

-    input_a_sum = Lambda(f1)(input_a)
-    input_b_sum = Lambda(f1)(input_b)
+    input_a_sum = Lambda(f2)(input_a)
+    input_b_sum = Lambda(f2)(input_b)

     output = Dense(1)(Dense(3, input_dim=4)(Concatenate()([input_a_sum, input_b_sum])))
     model = Model(inputs=[input_a, input_b], outputs=output)
6 changes: 2 additions & 4 deletions tests/tensorflow/test_load_saved_tensorflow_estimator.py
@@ -1,7 +1,6 @@
 import collections
 import json
 import os
-import pickle

 import iris_data_utils
 import numpy as np
@@ -50,7 +49,7 @@ def model_path(tmp_path):

 def save_or_log_tf_model_by_mlflow128(tmp_path, model_type, task_type, save_path=None):
     tf_tests_dir = os.path.dirname(__file__)
-    conda_env = get_or_create_conda_env(os.path.join(tf_tests_dir, "mlflow-128-tf-23-env.yaml"))
+    conda_env = get_or_create_conda_env(os.path.join(tf_tests_dir, "mlflow-128-tf-26-env.yaml"))
     output_data_file_path = os.path.join(tmp_path, "output_data.pkl")
     tracking_uri = mlflow.get_tracking_uri()
     exec_py_path = os.path.join(tf_tests_dir, "save_tf_estimator_model.py")
@@ -64,8 +63,7 @@ def save_or_log_tf_model_by_mlflow128(tmp_path, model_type, task_type, save_path=None):
         f"--task_type {task_type} "
         f"--save_path {save_path if save_path else 'none'}",
     )
-    with open(output_data_file_path, "rb") as f:
-        return ModelDataInfo(*pickle.load(f))
+    return ModelDataInfo(*pd.read_pickle(output_data_file_path))


 def test_load_model_from_remote_uri_succeeds(tmp_path, model_path, mock_s3_bucket, monkeypatch):
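
The pickle.load -> pd.read_pickle swap matters because pandas applies its pickle-compatibility shims while reading, so pickles written by the old pandas in the mlflow 1.28 environment still load, avoiding errors such as AttributeError: Can't get attribute '_unpickle_block'. A small sketch; the path is illustrative:

    import pandas as pd

    # pd.read_pickle unpickles arbitrary Python objects, not just DataFrames,
    # and remaps legacy pandas internals that plain pickle.load cannot resolve.
    obj = pd.read_pickle("output_data.pkl")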
8 changes: 5 additions & 3 deletions tests/tensorflow/test_mlflow_callback.py
@@ -38,7 +38,9 @@ def test_tf_mlflow_callback(log_every_epoch, log_every_n_steps):
         label,
         validation_data=(data, label),
         batch_size=4,
-        epochs=2,
+        # Increase the number of epochs so that
+        # logs are flushed correctly
+        epochs=5,
         callbacks=[mlflow_callback],
     )

@@ -49,5 +51,5 @@ def test_tf_mlflow_callback(log_every_epoch, log_every_n_steps):

     assert "loss" in run_metrics
     assert "sparse_categorical_accuracy" in run_metrics
-    assert model_info["optimizer_name"] == "Adam"
-    assert model_info["optimizer_learning_rate"] == "0.001"
+    assert model_info["optimizer_name"].lower() == "adam"
+    np.testing.assert_almost_equal(float(model_info["optimizer_learning_rate"]), 0.001)