Skip to content

Commit

Permalink
Add keras flavor for Keras 3 support (mlflow#10830)
Browse files Browse the repository at this point in the history
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: ernestwong-db <[email protected]>
  • Loading branch information
chenmoneygithub authored and ernestwong-db committed Feb 6, 2024
1 parent 3e4cd5c commit ecd2e36
Show file tree
Hide file tree
Showing 21 changed files with 1,365 additions and 120 deletions.
3 changes: 2 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,8 @@
("py:class", "ConfigDict"),
("py:class", "FieldInfo"),
("py:class", "ComputedFieldInfo"),
("py:class", "keras_core.src.callbacks.callback.Callback"),
("py:class", "keras.src.callbacks.callback.Callback"),
("py:class", "keras.callbacks.Callback"),
]


Expand Down
8 changes: 4 additions & 4 deletions docs/source/deep-learning/keras/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Introduction
Keras is a deep learning API written in Python, running on top of the machine learning platform TensorFlow.
It was developed with a focus on enabling fast experimentation.

Keras 3.0 (Keras Core) makes it possible to run Keras workflows on top of TensorFlow, JAX, and PyTorch.
Keras 3.0 makes it possible to run Keras workflows on top of TensorFlow, JAX, and PyTorch.
It also enables you to seamlessly integrate Keras components (like layers, models, or metrics) as part of
low-level TensorFlow, JAX, and PyTorch workflows.

Expand All @@ -22,18 +22,18 @@ you through how to use the callback for tracking experiments, as well as how to

.. raw:: html

<a href="quickstart/quickstart_keras_core.html" class="download-btn">View the Quickstart</a>
<a href="quickstart/quickstart_keras.html" class="download-btn">View the Quickstart</a>

To download the Keras 3.0 tutorial notebook to run in your environment, click the link below:

.. raw:: html

<a href="https://raw.githubusercontent.com/mlflow/mlflow/master/docs/source/deep-learning/keras/quickstart/quickstart_keras_core.ipynb"
<a href="https://raw.githubusercontent.com/mlflow/mlflow/master/docs/source/deep-learning/keras/quickstart/quickstart_keras.ipynb"
class="notebook-download-btn">Download the Quickstart of MLflow Keras Integration</a><br>


.. toctree::
:maxdepth: 1
:hidden:

quickstart/quickstart_keras_core.ipynb
quickstart/quickstart_keras.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"source": [
"# Get Started with Keras 3.0 + MLflow\n",
"\n",
"This tutorial is an end-to-end tutorial on training a MINST classifier with **Keras 3.0** and logging results with **MLflow**. It will demonstrate the use of `mlflow.keras_core.MLflowCallback`, and how to subclass it to implement custom logging logic.\n",
"This tutorial is an end-to-end tutorial on training a MINST classifier with **Keras 3.0** and logging results with **MLflow**. It will demonstrate the use of `mlflow.keras.MLflowCallback`, and how to subclass it to implement custom logging logic.\n",
"\n",
"**Keras** is a high-level api that is designed to be simple, flexible, and powerful - allowing everyone from beginners to advanced users to quickly build, train, and evaluate models. **Keras 3.0**, or Keras Core, is a full rewrite of the Keras codebase that rebases it on top of a modular backend architecture. It makes it possible to run Keras workflows on top of arbitrary frameworks — starting with TensorFlow, JAX, and PyTorch."
]
Expand All @@ -17,7 +17,7 @@
"source": [
"## Install Packages\n",
"\n",
"`pip install -q keras-core mlflow jax jaxlib torch tensorflow`"
"`pip install -q keras mlflow jax jaxlib torch tensorflow`"
]
},
{
Expand Down Expand Up @@ -54,7 +54,7 @@
}
],
"source": [
"import keras_core\n",
"import keras\n",
"import mlflow\n",
"import numpy as np"
]
Expand Down Expand Up @@ -84,7 +84,7 @@
}
],
"source": [
"(x_train, y_train), (x_test, y_test) = keras_core.datasets.mnist.load_data()\n",
"(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
"x_train = np.expand_dims(x_train, axis=3)\n",
"x_test = np.expand_dims(x_test, axis=3)\n",
"x_train[0].shape"
Expand Down Expand Up @@ -230,14 +230,14 @@
"\n",
"\n",
"def initialize_model():\n",
" model = keras_core.Sequential(\n",
" model = keras.Sequential(\n",
" [\n",
" keras_core.Input(shape=INPUT_SHAPE),\n",
" keras_core.layers.Conv2D(32, kernel_size=(3, 3), activation=\"relu\"),\n",
" keras_core.layers.Conv2D(32, kernel_size=(3, 3), activation=\"relu\"),\n",
" keras_core.layers.Conv2D(32, kernel_size=(3, 3), activation=\"relu\"),\n",
" keras_core.layers.GlobalAveragePooling2D(),\n",
" keras_core.layers.Dense(NUM_CLASSES, activation=\"softmax\"),\n",
" keras.Input(shape=INPUT_SHAPE),\n",
" keras.layers.Conv2D(32, kernel_size=(3, 3), activation=\"relu\"),\n",
" keras.layers.Conv2D(32, kernel_size=(3, 3), activation=\"relu\"),\n",
" keras.layers.Conv2D(32, kernel_size=(3, 3), activation=\"relu\"),\n",
" keras.layers.GlobalAveragePooling2D(),\n",
" keras.layers.Dense(NUM_CLASSES, activation=\"softmax\"),\n",
" ]\n",
" )\n",
" return model\n",
Expand Down Expand Up @@ -295,8 +295,8 @@
"model = initialize_model()\n",
"\n",
"model.compile(\n",
" loss=keras_core.losses.SparseCategoricalCrossentropy(),\n",
" optimizer=keras_core.optimizers.Adam(),\n",
" loss=keras.losses.SparseCategoricalCrossentropy(),\n",
" optimizer=keras.optimizers.Adam(),\n",
" metrics=[\"accuracy\"],\n",
")\n",
"\n",
Expand Down Expand Up @@ -352,8 +352,8 @@
"model = initialize_model()\n",
"\n",
"model.compile(\n",
" loss=keras_core.losses.SparseCategoricalCrossentropy(),\n",
" optimizer=keras_core.optimizers.Adam(),\n",
" loss=keras.losses.SparseCategoricalCrossentropy(),\n",
" optimizer=keras.optimizers.Adam(),\n",
" metrics=[\"accuracy\"],\n",
")\n",
"\n",
Expand All @@ -364,9 +364,7 @@
" batch_size=BATCH_SIZE,\n",
" epochs=EPOCHS,\n",
" validation_split=0.1,\n",
" callbacks=[\n",
" mlflow.keras_core.MLflowCallback(run, log_every_epoch=False, log_every_n_steps=5)\n",
" ],\n",
" callbacks=[mlflow.keras.MLflowCallback(run, log_every_epoch=False, log_every_n_steps=5)],\n",
" )"
]
},
Expand Down Expand Up @@ -448,8 +446,8 @@
"model = initialize_model()\n",
"\n",
"model.compile(\n",
" loss=keras_core.losses.SparseCategoricalCrossentropy(),\n",
" optimizer=keras_core.optimizers.Adam(),\n",
" loss=keras.losses.SparseCategoricalCrossentropy(),\n",
" optimizer=keras.optimizers.Adam(),\n",
" metrics=[\"accuracy\"],\n",
")\n",
"\n",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mlflow.keras_core
mlflow.keras
==================

.. automodule:: mlflow.keras_core
.. automodule:: mlflow.keras
:members:
:undoc-members:
:show-inheritance:
2 changes: 1 addition & 1 deletion mlflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
gluon = LazyLoader("mlflow.gluon", globals(), "mlflow.gluon")
h2o = LazyLoader("mlflow.h2o", globals(), "mlflow.h2o")
johnsnowlabs = LazyLoader("mlflow.johnsnowlabs", globals(), "mlflow.johnsnowlabs")
keras_core = LazyLoader("mlflow.keras_core", globals(), "mlflow.keras_core")
keras = LazyLoader("mlflow.keras", globals(), "mlflow.keras")
langchain = LazyLoader("mlflow.langchain", globals(), "mlflow.langchain")
lightgbm = LazyLoader("mlflow.lightgbm", globals(), "mlflow.lightgbm")
llm = LazyLoader("mlflow.llm", globals(), "mlflow.llm")
Expand Down
31 changes: 23 additions & 8 deletions mlflow/keras/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,24 @@
from mlflow.tensorflow import (
# Redirect `mlflow.keras._load_pyfunc` to `mlflow.tensorflow._load_pyfunc`,
# For backwards compatibility on loading keras model saved by old mlflow versions.
_load_pyfunc, # noqa: F401
autolog, # noqa: F401
load_model, # noqa: F401
log_model, # noqa: F401
save_model, # noqa: F401
# MLflow Keras 3 flavor.

from mlflow.keras.autolog import autolog
from mlflow.keras.callback import MLflowCallback
from mlflow.keras.load import _load_pyfunc, load_model
from mlflow.keras.save import (
get_default_conda_env,
get_default_pip_requirements,
log_model,
save_model,
)

FLAVOR_NAME = "keras"

__all__ = [
"_load_pyfunc",
"MLflowCallback",
"autolog",
"load_model",
"save_model",
"log_model",
"get_default_pip_requirements",
"get_default_conda_env",
]
Loading

0 comments on commit ecd2e36

Please sign in to comment.