Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Lok <[email protected]>
  • Loading branch information
daniellok-db committed May 10, 2024
1 parent 8605e0d commit 073a212
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 2 deletions.
5 changes: 5 additions & 0 deletions mlflow/pyfunc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def _save_model_with_class_artifacts_params(
.. Note:: Experimental: This parameter may change or be removed in a future release
without warning.
"""
_logger.warning("@@@ TEST", mlflow_model, artifacts)
if mlflow_model is None:
mlflow_model = Model()

Expand Down Expand Up @@ -384,16 +385,20 @@ def _save_model_with_class_artifacts_params(
mlflow_model.model_size_bytes = size
mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))

_logger.warning("@@@ REQUIREMENTS", conda_env, pip_requirements, extra_pip_requirements)

if conda_env is None:
if pip_requirements is None:
default_reqs = get_default_pip_requirements()
_logger.warning("@@@ DEFAULTS", default_reqs)
# To ensure `_load_pyfunc` can successfully load the model during the dependency
# inference, `mlflow_model.save` must be called beforehand to save an MLmodel file.
inferred_reqs = mlflow.models.infer_pip_requirements(
path,
mlflow.pyfunc.FLAVOR_NAME,
fallback=default_reqs,
)
_logger.warning("@@@ INFERRED", inferred_reqs)
default_reqs = sorted(set(inferred_reqs).union(default_reqs))
else:
default_reqs = None
Expand Down
1 change: 1 addition & 0 deletions mlflow/utils/_capture_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def __exit__(self, *_, **__):
# Revert the patches
builtins.__import__ = self.original_import
importlib.import_module = self.original_import_module
print("@@@ imports", self.imported_modules)

Check failure on line 104 in mlflow/utils/_capture_modules.py

View workflow job for this annotation

GitHub Actions / lint

`print` found. See https://docs.astral.sh/ruff/rules/T201 for how to fix this error.


def parse_args():
Expand Down
4 changes: 3 additions & 1 deletion mlflow/utils/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,8 +446,10 @@ def infer_pip_requirements(model_uri, flavor, fallback=None):
"""
try:
_logger.warning("@@@ INSIDE INFER", model_uri, flavor)
return _infer_requirements(model_uri, flavor)
except Exception:
except Exception as e:
_logger.warning("@@@ ERROR INFERRING", e)
if fallback is not None:
_logger.warning(
msg=_INFER_PIP_REQUIREMENTS_GENERAL_ERROR_MESSAGE.format(
Expand Down
11 changes: 10 additions & 1 deletion mlflow/utils/requirements_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def _run_command(cmd, timeout_seconds, env=None):
stdout, stderr = proc.communicate()
stdout = stdout.decode("utf-8")
stderr = stderr.decode("utf-8")
_logger.warning("@@@@ STDOUT", stdout, stderr)
if proc.returncode != 0:
msg = "\n".join(
[
Expand Down Expand Up @@ -335,7 +336,9 @@ def _capture_imported_modules(model_uri, flavor):
)

with open(output_file) as f:
return f.read().splitlines()
contents = f.read().splitlines()
_logger.warning("@@@ OUTPUT FILE CONTENTS", contents)
return contents


DATABRICKS_MODULES_TO_PACKAGES = {
Expand Down Expand Up @@ -424,9 +427,13 @@ def _infer_requirements(model_uri, flavor):
_PYPI_PACKAGE_INDEX = _load_pypi_package_index()

modules = _capture_imported_modules(model_uri, flavor)
_logger.warning("@@@ MODULES", modules)
packages = _flatten([_MODULES_TO_PACKAGES.get(module, []) for module in modules])
_logger.warning("@@@ PACKAGES 1", packages)
packages = map(_normalize_package_name, packages)
_logger.warning("@@@ PACKAGES 2", packages)
packages = _prune_packages(packages)
_logger.warning("@@@ PACKAGES 3", packages)
excluded_packages = [
# Certain packages (e.g. scikit-learn 0.24.2) imports `setuptools` or `pkg_resources`
# (a module provided by `setuptools`) to process or interact with package metadata.
Expand All @@ -439,9 +446,11 @@ def _infer_requirements(model_uri, flavor):
*_MODULES_TO_PACKAGES.get("mlflow", []),
]
packages = packages - set(excluded_packages)
_logger.warning("@@@ PACKAGES 4", packages, excluded_packages)

# manually exclude mlflow[gateway] as it isn't listed separately in PYPI_PACKAGE_INDEX
unrecognized_packages = packages - _PYPI_PACKAGE_INDEX.package_names - {"mlflow[gateway]"}
_logger.warning("@@@ PACKAGES 5", unrecognized_packages)
if unrecognized_packages:
_logger.warning(
"The following packages were not found in the public PyPI package index as of"
Expand Down

0 comments on commit 073a212

Please sign in to comment.