-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add keras flavor for Keras 3 support #10830
Add keras flavor for Keras 3 support #10830
Conversation
Documentation preview for 4f876fd will be available here when this CircleCI job completes successfully. More info
|
docs/source/deep-learning/keras/quickstart/quickstart_keras_core.ipynb
Outdated
Show resolved
Hide resolved
setup.py
Outdated
@@ -96,7 +96,7 @@ def run(self): | |||
print("\n".join(dependencies)) | |||
|
|||
|
|||
MINIMUM_SUPPORTED_PYTHON_VERSION = "3.8" | |||
MINIMUM_SUPPORTED_PYTHON_VERSION = "3.9" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MINIMUM_SUPPORTED_PYTHON_VERSION = "3.9" | |
MINIMUM_SUPPORTED_PYTHON_VERSION = "3.8" |
mlflow/ml_package_versions.py
Outdated
}, | ||
"models": { | ||
"minimum": "3.0.2", | ||
"maximum": "3.10.0", # Dummy version number to remove cap. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can use the latest version. We don't know if the current implementation really works with 3.10.0.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sgtm!
@@ -1,3 +0,0 @@ | |||
from mlflow.keras_core.callback import MLflowCallback | |||
|
|||
__all__ = ["MLflowCallback"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we raise a deprecation warning?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it was marked as experimental, and I don't think anyone is using it, so I guess we can just delete it.
log_every_n_steps=log_every_n_steps, | ||
) | ||
callbacks = _add_mlflow_to_keras_callbacks(callbacks, mlflow_callback) | ||
kwargs["callbacks"] = callbacks |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
kwargs["callbacks"] = callbacks | |
kwargs["callbacks"] = [*callbacks, mlflow_callback] |
_add_mlflow_to_keras_callbacks
can just check if callbacks
contains an MLflowCallback
object.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
makes sense to me, I am renaming it to _check_existing_mlflow_callback
and append the mlflow_callback
beforehand.
_logger.warning(f"Failed to log dataset information to MLflow. Reason: {e}") | ||
|
||
# Add `MLflowCallback` to the callback list. | ||
callbacks = args[5] if len(args) >= 6 else kwargs.get("callbacks", []) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any chance that callbacks
is None?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm there is no enforcement, but I doubt anyone will want to do that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
callbacks = args[5] if len(args) >= 6 else kwargs.get("callbacks", []) | |
callbacks = args[5] if len(args) >= 6 else kwargs.get("callbacks") or [] |
in case callbacks
is None
mlflow/keras/callback.py
Outdated
@@ -45,15 +45,13 @@ class MLflowCallback(keras.callbacks.Callback): | |||
label, | |||
batch_size=4, | |||
epochs=2, | |||
callbacks=[mlflow.keras_core.MLflowCallback(run)], | |||
callbacks=[mlflow.keras_core.MLflowCallback()], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
callbacks=[mlflow.keras_core.MLflowCallback()], | |
callbacks=[mlflow.keras.MLflowCallback()], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch!
mlflow/keras/autolog.py
Outdated
@@ -0,0 +1,261 @@ | |||
# MLflow autologging support for Keras 3. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment is useless. Let's remove it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I made a mistake here, it should be a docstring instead of a comment. Basically according to Google python style guide, each module should have a top-level docstring.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you also mlflow/ml-package-versions.yml?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure!
self.metrics_logger.record_metrics(logs, epoch) | ||
log_metrics(logs, step=epoch, synchronous=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
curious why we need this change
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good question - a few months ago we added async logging support to our fluent API, so we on longer need to use the specific logger as before.
AUTOLOGGING_INTEGRATIONS.pop("keras", None) | ||
|
||
|
||
def _create_keras_model(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it possible to test a pytorch model?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This Keras model can be a PyTorch, Tensorflow or JAX model, which is controlled by environment variable KERAS_BACKEND
. It sounds a bit weird at first look, I suggest giving it a try, which i find pretty cool!
`save_model()` and `log_model()` produce a pip environment that, at minimum, contains these | ||
requirements. | ||
""" | ||
return [_get_pinned_requirement("keras")] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a way to detect the backend framework (TF, torch, Jax)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very good point, let me add the backend framework information.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@chenmoneygithub can you add it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added it to flavor option above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From Haru: let's use this utility function to retrieve the backend package + version.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tested and verified without any extra code, _get_pinned_requirement
already retrieves the backend requirement, idk why tho.
69ddd1c
to
66c2c9b
Compare
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
c3a8ec6
to
7b4e0f8
Compare
Signed-off-by: chenmoneygithub <[email protected]>
close and reopen to disable the CI cache. |
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Signed-off-by: chenmoneygithub <[email protected]> Signed-off-by: ernestwong-db <[email protected]>
🛠 DevTools 🛠
Install mlflow from this PR
Checkout with GitHub CLI
Related Issues/PRs
#xxxWhat changes are proposed in this pull request?
Add keras flavor for Keras 3 support, this supports all backends Keras 3 supports, i.e., Tensorflow, PyTorch and JAX.
This code is adapted from tensorflow flavor, but mostly rewritten. Redundant logic is removed.
How is this PR tested?
Does this PR require documentation update?
Release Notes
Is this a user-facing change?
What component(s), interfaces, languages, and integrations does this PR affect?
Components
area/artifacts
: Artifact stores and artifact loggingarea/build
: Build and test infrastructure for MLflowarea/deployments
: MLflow Deployments client APIs, server, and third-party Deployments integrationsarea/docs
: MLflow documentation pagesarea/examples
: Example codearea/model-registry
: Model Registry service, APIs, and the fluent client calls for Model Registryarea/models
: MLmodel format, model serialization/deserialization, flavorsarea/recipes
: Recipes, Recipe APIs, Recipe configs, Recipe Templatesarea/projects
: MLproject format, project running backendsarea/scoring
: MLflow Model server, model deployment tools, Spark UDFsarea/server-infra
: MLflow Tracking server backendarea/tracking
: Tracking Service, tracking client APIs, autologgingInterface
area/uiux
: Front-end, user experience, plotting, JavaScript, JavaScript dev serverarea/docker
: Docker use across MLflow's components, such as MLflow Projects and MLflow Modelsarea/sqlalchemy
: Use of SQLAlchemy in the Tracking Service or Model Registryarea/windows
: Windows supportLanguage
language/r
: R APIs and clientslanguage/java
: Java APIs and clientslanguage/new
: Proposals for new client languagesIntegrations
integrations/azure
: Azure and Azure ML integrationsintegrations/sagemaker
: SageMaker integrationsintegrations/databricks
: Databricks integrationsHow should the PR be classified in the release notes? Choose one:
rn/none
- No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" sectionrn/breaking-change
- The PR will be mentioned in the "Breaking Changes" sectionrn/feature
- A new user-facing feature worth mentioning in the release notesrn/bug-fix
- A user-facing bug fix worth mentioning in the release notesrn/documentation
- A user-facing documentation change worth mentioning in the release notes