Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add keras flavor for Keras 3 support #10830

Merged
merged 16 commits into from
Feb 1, 2024

Conversation

chenmoneygithub
Copy link
Collaborator

@chenmoneygithub chenmoneygithub commented Jan 16, 2024

🛠 DevTools 🛠

Open in GitHub Codespaces

Install mlflow from this PR

pip install git+https://github.com/mlflow/mlflow.git@refs/pull/10830/merge

Checkout with GitHub CLI

gh pr checkout 10830

Related Issues/PRs

#xxx

What changes are proposed in this pull request?

Add keras flavor for Keras 3 support, this supports all backends Keras 3 supports, i.e., Tensorflow, PyTorch and JAX.

This code is adapted from tensorflow flavor, but mostly rewritten. Redundant logic is removed.

How is this PR tested?

  • Existing unit/integration tests
  • New unit/integration tests
  • Manual tests

Does this PR require documentation update?

  • No. You can skip the rest of this section.
  • Yes. I've updated:
    • Examples
    • API references
    • Instructions

Release Notes

Is this a user-facing change?

  • No. You can skip the rest of this section.
  • Yes. Give a description of this change to be included in the release notes for MLflow users.

What component(s), interfaces, languages, and integrations does this PR affect?

Components

  • area/artifacts: Artifact stores and artifact logging
  • area/build: Build and test infrastructure for MLflow
  • area/deployments: MLflow Deployments client APIs, server, and third-party Deployments integrations
  • area/docs: MLflow documentation pages
  • area/examples: Example code
  • area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • area/models: MLmodel format, model serialization/deserialization, flavors
  • area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
  • area/projects: MLproject format, project running backends
  • area/scoring: MLflow Model server, model deployment tools, Spark UDFs
  • area/server-infra: MLflow Tracking server backend
  • area/tracking: Tracking Service, tracking client APIs, autologging

Interface

  • area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
  • area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
  • area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • area/windows: Windows support

Language

  • language/r: R APIs and clients
  • language/java: Java APIs and clients
  • language/new: Proposals for new client languages

Integrations

  • integrations/azure: Azure and Azure ML integrations
  • integrations/sagemaker: SageMaker integrations
  • integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:

  • rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
  • rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
  • rn/feature - A new user-facing feature worth mentioning in the release notes
  • rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
  • rn/documentation - A user-facing documentation change worth mentioning in the release notes

Copy link

github-actions bot commented Jan 16, 2024

Documentation preview for 4f876fd will be available here when this CircleCI job completes successfully.

More info

@chenmoneygithub chenmoneygithub marked this pull request as draft January 16, 2024 23:16
@github-actions github-actions bot added area/tracking Tracking service, tracking client APIs, autologging rn/feature Mention under Features in Changelogs. labels Jan 16, 2024
setup.py Outdated
@@ -96,7 +96,7 @@ def run(self):
print("\n".join(dependencies))


MINIMUM_SUPPORTED_PYTHON_VERSION = "3.8"
MINIMUM_SUPPORTED_PYTHON_VERSION = "3.9"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
MINIMUM_SUPPORTED_PYTHON_VERSION = "3.9"
MINIMUM_SUPPORTED_PYTHON_VERSION = "3.8"

},
"models": {
"minimum": "3.0.2",
"maximum": "3.10.0", # Dummy version number to remove cap.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use the latest version. We don't know if the current implementation really works with 3.10.0.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sgtm!

@@ -1,3 +0,0 @@
from mlflow.keras_core.callback import MLflowCallback

__all__ = ["MLflowCallback"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we raise a deprecation warning?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it was marked as experimental, and I don't think anyone is using it, so I guess we can just delete it.

log_every_n_steps=log_every_n_steps,
)
callbacks = _add_mlflow_to_keras_callbacks(callbacks, mlflow_callback)
kwargs["callbacks"] = callbacks
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
kwargs["callbacks"] = callbacks
kwargs["callbacks"] = [*callbacks, mlflow_callback]

_add_mlflow_to_keras_callbacks can just check if callbacks contains an MLflowCallback object.

Copy link
Collaborator Author

@chenmoneygithub chenmoneygithub Jan 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense to me, I am renaming it to _check_existing_mlflow_callback and append the mlflow_callback beforehand.

_logger.warning(f"Failed to log dataset information to MLflow. Reason: {e}")

# Add `MLflowCallback` to the callback list.
callbacks = args[5] if len(args) >= 6 else kwargs.get("callbacks", [])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any chance that callbacks is None?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm there is no enforcement, but I doubt anyone will want to do that.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
callbacks = args[5] if len(args) >= 6 else kwargs.get("callbacks", [])
callbacks = args[5] if len(args) >= 6 else kwargs.get("callbacks") or []

in case callbacks is None

@@ -45,15 +45,13 @@ class MLflowCallback(keras.callbacks.Callback):
label,
batch_size=4,
epochs=2,
callbacks=[mlflow.keras_core.MLflowCallback(run)],
callbacks=[mlflow.keras_core.MLflowCallback()],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
callbacks=[mlflow.keras_core.MLflowCallback()],
callbacks=[mlflow.keras.MLflowCallback()],

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch!

@@ -0,0 +1,261 @@
# MLflow autologging support for Keras 3.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is useless. Let's remove it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I made a mistake here, it should be a docstring instead of a comment. Basically according to Google python style guide, each module should have a top-level docstring.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you also mlflow/ml-package-versions.yml?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure!

self.metrics_logger.record_metrics(logs, epoch)
log_metrics(logs, step=epoch, synchronous=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious why we need this change

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good question - a few months ago we added async logging support to our fluent API, so we on longer need to use the specific logger as before.

AUTOLOGGING_INTEGRATIONS.pop("keras", None)


def _create_keras_model():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to test a pytorch model?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This Keras model can be a PyTorch, Tensorflow or JAX model, which is controlled by environment variable KERAS_BACKEND. It sounds a bit weird at first look, I suggest giving it a try, which i find pretty cool!

`save_model()` and `log_model()` produce a pip environment that, at minimum, contains these
requirements.
"""
return [_get_pinned_requirement("keras")]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to detect the backend framework (TF, torch, Jax)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very good point, let me add the backend framework information.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chenmoneygithub can you add it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added it to flavor option above

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From Haru: let's use this utility function to retrieve the backend package + version.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested and verified without any extra code, _get_pinned_requirement already retrieves the backend requirement, idk why tho.

Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
@chenmoneygithub
Copy link
Collaborator Author

close and reopen to disable the CI cache.

Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: chenmoneygithub <[email protected]>
Copy link
Member

@harupy harupy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@chenmoneygithub chenmoneygithub merged commit 093782b into mlflow:master Feb 1, 2024
39 checks passed
ernestwong-db pushed a commit to ernestwong-db/mlflow that referenced this pull request Feb 6, 2024
Signed-off-by: chenmoneygithub <[email protected]>
Signed-off-by: ernestwong-db <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
area/tracking Tracking service, tracking client APIs, autologging rn/feature Mention under Features in Changelogs.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants