Skip to content

Commit

Permalink
Avoid hard erroring if tf is not installed (#2028)
Browse files Browse the repository at this point in the history
Instead we only want to fail on features that need it (mainly
tokenization).
  • Loading branch information
mattdangerw authored Dec 23, 2024
1 parent 78df6f5 commit 18ae93c
Showing 1 changed file with 34 additions and 45 deletions.
79 changes: 34 additions & 45 deletions keras_hub/src/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,10 @@

import keras
from absl import logging
from packaging.version import parse

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.utils.keras_utils import print_msg

try:
import tensorflow as tf
except ImportError:
raise ImportError(
"To use `keras_hub`, please install Tensorflow: "
"`pip install tensorflow`. The TensorFlow package is required for data "
"preprocessing with any backend."
)

try:
import kagglehub
from kagglehub.exceptions import KaggleApiHTTPError
Expand Down Expand Up @@ -173,22 +163,8 @@ def get_file(preset, path):
)
else:
raise ValueError(message)

elif scheme in tf.io.gfile.get_registered_schemes():
url = os.path.join(preset, path)
subdir = preset.replace("://", "_").replace("-", "_").replace("/", "_")
filename = os.path.basename(path)
subdir = os.path.join(subdir, os.path.dirname(path))
try:
return copy_gfile_to_cache(
filename,
url,
cache_subdir=os.path.join("models", subdir),
)
except (tf.errors.PermissionDeniedError, tf.errors.NotFoundError) as e:
raise FileNotFoundError(
f"`{path}` doesn't exist in preset directory `{preset}`.",
) from e
elif scheme in tf_registered_schemes():
return tf_copy_gfile_to_cache(preset, path)
elif scheme == HF_SCHEME:
if huggingface_hub is None:
raise ImportError(
Expand Down Expand Up @@ -237,29 +213,48 @@ def get_file(preset, path):
)


def copy_gfile_to_cache(filename, url, cache_subdir):
def tf_registered_schemes():
try:
import tensorflow as tf

return tf.io.gfile.get_registered_schemes()
except ImportError:
return []


def tf_copy_gfile_to_cache(preset, path):
"""Much of this is adapted from get_file of keras core."""
if "KERAS_HOME" in os.environ:
cachdir_base = os.environ.get("KERAS_HOME")
base_dir = os.environ.get("KERAS_HOME")
else:
cachdir_base = os.path.expanduser(os.path.join("~", ".keras"))
if not os.access(cachdir_base, os.W_OK):
cachdir_base = os.path.join("/tmp", ".keras")
cachedir = os.path.join(cachdir_base, cache_subdir)
os.makedirs(cachedir, exist_ok=True)

fpath = os.path.join(cachedir, filename)
if not os.path.exists(fpath):
base_dir = os.path.expanduser(os.path.join("~", ".keras"))
if not os.access(base_dir, os.W_OK):
base_dir = os.path.join("/tmp", ".keras")

url = os.path.join(preset, path)
model_dir = preset.replace("://", "_").replace("-", "_").replace("/", "_")
local_path = os.path.join(base_dir, "models", model_dir, path)

if not os.path.exists(local_path):
print_msg(f"Downloading data from {url}")
try:
tf.io.gfile.copy(url, fpath)
import tensorflow as tf

os.make_dirs(os.path.dirname(local_path), exist_ok=True)
tf.io.gfile.copy(url, local_path)
except Exception as e:
# gfile.copy will leave an empty file after an error.
# Work around this bug.
os.remove(fpath)
os.remove(local_path)
if isinstance(
e, tf.errors.PermissionDeniedError, tf.errors.NotFoundError
):
raise FileNotFoundError(
f"`{path}` doesn't exist in preset directory `{preset}`.",
) from e
raise e

return fpath
return local_path


def check_file_exists(preset, path):
Expand Down Expand Up @@ -394,12 +389,6 @@ def upload_preset(
"Uploading a model to Kaggle Hub requires the `kagglehub` "
"package. Please install with `pip install kagglehub`."
)
if parse(kagglehub.__version__) < parse("0.2.4"):
raise ImportError(
"Uploading a model to Kaggle Hub requires the `kagglehub` "
"package version `0.2.4` or higher. Please upgrade with "
"`pip install --upgrade kagglehub`."
)
kaggle_handle = uri.removeprefix(KAGGLE_PREFIX)
kagglehub.model_upload(kaggle_handle, preset)
elif uri.startswith(HF_PREFIX):
Expand Down

0 comments on commit 18ae93c

Please sign in to comment.