Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add keras mixin #230

Merged
merged 10 commits into from
Aug 10, 2021
24 changes: 19 additions & 5 deletions src/huggingface_hub/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,25 @@

_tf_version = "N/A"
_tf_available = False
try:
_tf_version = importlib_metadata.version("tensorflow")
_tf_available = True
except importlib_metadata.PackageNotFoundError:
pass
_tf_candidates = (
"tensorflow",
"tensorflow-cpu",
"tensorflow-gpu",
"tf-nightly",
"tf-nightly-cpu",
"tf-nightly-gpu",
"intel-tensorflow",
"intel-tensorflow-avx512",
"tensorflow-rocm",
"tensorflow-macos",
)
for package_name in _tf_candidates:
try:
_tf_version = importlib_metadata.version(package_name)
_tf_available = True
break
except importlib_metadata.PackageNotFoundError:
pass


def is_torch_available():
Expand Down
4 changes: 2 additions & 2 deletions src/huggingface_hub/hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ def save_pretrained(
json.dump(config, f)

# saving model weights/files
self._save_pretrained(save_directory)
self._save_pretrained(save_directory, **kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The kwargs are advertised as the kwargs to be passed to push_to_hub yet they're passed to _save_pretrained. I'd argue the kwargs to pass to each method ought to be different!


if push_to_hub:
return self.push_to_hub(save_directory, **kwargs)

def _save_pretrained(self, save_directory):
def _save_pretrained(self, save_directory, **kwargs):
"""
Overwrite this method in subclass to define how to save your model.
"""
Expand Down
108 changes: 108 additions & 0 deletions src/huggingface_hub/keras_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import logging
import os
from pathlib import Path

from huggingface_hub import ModelHubMixin, hf_hub_download


logger = logging.getLogger(__name__)


class KerasModelHubMixin(ModelHubMixin):
    # File names used inside a model repo for the config and the h5 weights.
    _CONFIG_NAME = "config.json"
    _WEIGHTS_NAME = "tf_model.h5"

    def __init__(self, *args, **kwargs):
        """
        Mix this class with your Keras model class for an easy process of
        saving & loading from the Hugging Face Hub.

        NOTE - Dummy inputs are required to save/load models using this mixin.
        When saving, you are required to either:

        1. Assign an attribute to your class, ``self.dummy_inputs``, that
           defines inputs to be passed to the model's ``call`` function to
           build the model.
        2. Pass the ``dummy_inputs`` kwarg to ``save_pretrained``. We will save
           this along with the model (as if it were an attribute).

        Example::

            >>> from huggingface_hub import KerasModelHubMixin

            >>> class MyModel(tf.keras.Model, KerasModelHubMixin):
            ...     def __init__(self, **kwargs):
            ...         super().__init__()
            ...         self.config = kwargs.pop("config", None)
            ...         self.dummy_inputs = ...
            ...         self.layer = ...
            ...     def call(self, ...)
            ...         return ...

            >>> model = MyModel()
            >>> model.save_pretrained("mymodel", push_to_hub=False)  # Saving model weights in the directory
            >>> model.push_to_hub("mymodel", "model-1")  # Pushing model-weights to hf-hub

            >>> # Downloading weights from hf-hub & model will be initialized from those weights
            >>> model = MyModel.from_pretrained("username/mymodel@main")
        """

    def _save_pretrained(self, save_directory, dummy_inputs=None, **kwargs):
        """Save the model's weights as an h5 file in ``save_directory``.

        Args:
            save_directory: directory (str or Path) to write the weights into.
            dummy_inputs: optional inputs used to build the model before
                saving; falls back to ``self.dummy_inputs`` when not given.

        Raises:
            RuntimeError: if no dummy inputs are available from either source.
        """
        # Prefer explicitly passed dummy inputs; fall back to the attribute.
        dummy_inputs = (
            dummy_inputs
            if dummy_inputs is not None
            else getattr(self, "dummy_inputs", None)
        )

        if dummy_inputs is None:
            raise RuntimeError(
                "You must either provide dummy inputs or have them assigned as an attribute of this model"
            )

        # Run one forward pass so the model is built before its weights are
        # serialized (Keras subclassed models build lazily on first call).
        _ = self(dummy_inputs, training=False)

        save_directory = Path(save_directory)
        model_file = save_directory / self._WEIGHTS_NAME
        self.save_weights(model_file)
        logger.info(f"Model weights saved in {model_file}")

    @classmethod
    def _from_pretrained(
        cls,
        model_id,
        # Download-related arguments are optional so this can also be called
        # directly; parameter names/order are unchanged for positional callers.
        revision=None,
        cache_dir=None,
        force_download=False,
        proxies=None,
        resume_download=False,
        local_files_only=False,
        use_auth_token=None,
        by_name=False,
        **model_kwargs,
    ):
        """Instantiate ``cls`` and load weights from a local dir or the Hub.

        Args:
            model_id: local directory containing the weights, or a Hub repo id.
            by_name: forwarded to ``load_weights``.
            **model_kwargs: forwarded to the model constructor.

        Returns:
            The model instance with restored weights.

        Raises:
            ValueError: if the constructed model has no ``dummy_inputs``.
        """
        if os.path.isdir(model_id):
            # Use logger (not print) for consistency with _save_pretrained.
            logger.info("Loading weights from local directory")
            model_file = os.path.join(model_id, cls._WEIGHTS_NAME)
        else:
            model_file = hf_hub_download(
                repo_id=model_id,
                filename=cls._WEIGHTS_NAME,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                use_auth_token=use_auth_token,
                local_files_only=local_files_only,
            )

        model = cls(**model_kwargs)

        # Bug fix: the previous condition was inverted and raised exactly when
        # dummy_inputs WAS present. Raise when it is missing/None, since the
        # model cannot be built (and weights cannot be loaded) without it.
        if getattr(model, "dummy_inputs", None) is None:
            raise ValueError("Model must have a dummy_inputs attribute")

        # Build the model so that load_weights has variables to populate.
        _ = model(model.dummy_inputs, training=False)

        model.load_weights(model_file, by_name=by_name)

        # Forward pass after loading, as in the original flow, to make sure
        # the restored weights are immediately usable.
        _ = model(model.dummy_inputs, training=False)

        return model
153 changes: 153 additions & 0 deletions tests/test_keras_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import os
import shutil
import time
import unittest

from huggingface_hub import HfApi
from huggingface_hub.file_download import is_tf_available
from huggingface_hub.keras_mixin import KerasModelHubMixin

from .testing_constants import ENDPOINT_STAGING, PASS, USER
from .testing_utils import set_write_permission_and_retry


# Unique repo name per test run (millisecond-ish timestamp suffix).
REPO_NAME = "mixin-repo-{}".format(int(time.time() * 10e3))

# Scratch directory for repos created by these tests.
WORKING_REPO_SUBDIR = "fixtures/working_repo_3"
WORKING_REPO_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), WORKING_REPO_SUBDIR
)

# Import TensorFlow only when it is installed; tests are skipped otherwise.
if is_tf_available():
    import tensorflow as tf


def require_tf(test_case):
    """
    Decorator marking a test that requires TensorFlow.

    These tests are skipped when TensorFlow isn't installed.

    """
    # Guard-clause form: pass the test case through untouched when TF exists.
    if is_tf_available():
        return test_case
    return unittest.skip("test requires Tensorflow")(test_case)


if is_tf_available():

    class DummyModel(tf.keras.Model, KerasModelHubMixin):
        """Minimal Keras model used to exercise the hub mixin in tests."""

        def __init__(self, **kwargs):
            super().__init__()
            self.config = kwargs.pop("config", None)
            self.l1 = tf.keras.layers.Dense(2, activation="relu")
            # A fixed 2x2 ones tensor is enough to build the single layer.
            batch_size = feature_dim = 2
            self.dummy_inputs = tf.ones([batch_size, feature_dim])

        def call(self, x):
            return self.l1(x)


else:
    # Placeholder so the module still imports when TF is unavailable.
    DummyModel = None


@require_tf
class HubMixingCommonTest(unittest.TestCase):
    # Shared API client pointed at the staging endpoint; subclasses reuse it.
    _api = HfApi(endpoint=ENDPOINT_STAGING)
Comment on lines +57 to +58
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this super class really needed? We only have one test class so maybe the API wrapper should be instantiated in setUpClass?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're probably right. I didn't want to mess with it too much, though. again, just following the previous file. I did find the superclass useful (if I remember right) when I played around with parametrizing the tests so we could run them over TF/PyTorch/etc.



@require_tf
class HubMixingTest(HubMixingCommonTest):
    def tearDown(self) -> None:
        # The working repo may not exist if a test never wrote to disk.
        try:
            shutil.rmtree(WORKING_REPO_DIR, onerror=set_write_permission_and_retry)
        except FileNotFoundError:
            pass

    @classmethod
    def setUpClass(cls):
        """
        Share this valid token in all tests below.
        """
        cls._token = cls._api.login(username=USER, password=PASS)

    def test_save_pretrained(self):
        model = DummyModel()

        # Without a config, only the weights file should be written.
        model.save_pretrained(f"{WORKING_REPO_DIR}/{REPO_NAME}")
        files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}")
        self.assertTrue("tf_model.h5" in files)
        self.assertEqual(len(files), 1)

        # With a config, both config.json and the weights should be present.
        model.save_pretrained(
            f"{WORKING_REPO_DIR}/{REPO_NAME}", config={"num": 12, "act": "gelu"}
        )
        files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}")
        self.assertTrue("config.json" in files)
        self.assertTrue("tf_model.h5" in files)
        self.assertEqual(len(files), 2)

    def test_keras_from_pretrained_weights(self):
        model = DummyModel()
        # Clear the attribute to verify the dummy_inputs kwarg path works.
        model.dummy_inputs = None
        model.save_pretrained(
            f"{WORKING_REPO_DIR}/{REPO_NAME}", dummy_inputs=tf.ones([2, 2])
        )
        assert model.built
        new_model = DummyModel.from_pretrained(f"{WORKING_REPO_DIR}/{REPO_NAME}")

        # Check the reloaded model's weights match the original model's weights
        self.assertTrue(tf.reduce_all(tf.equal(new_model.weights[0], model.weights[0])))

        # Check a new model's weights are not the same as the reloaded model's weights
        another_model = DummyModel()
        another_model(tf.ones([2, 2]))
        self.assertFalse(
            tf.reduce_all(tf.equal(new_model.weights[0], another_model.weights[0]))
            .numpy()
            .item()
        )

    def test_rel_path_from_pretrained(self):
        model = DummyModel()
        model.save_pretrained(
            f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED",
            config={"num": 10, "act": "gelu_fast"},
        )

        model = DummyModel.from_pretrained(
            f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED"
        )
        # Consistency: use assertDictEqual like test_abs_path_from_pretrained
        # instead of assertTrue on a == comparison (better failure messages).
        self.assertDictEqual(model.config, {"num": 10, "act": "gelu_fast"})

    def test_abs_path_from_pretrained(self):
        model = DummyModel()
        model.save_pretrained(
            f"{WORKING_REPO_DIR}/{REPO_NAME}-FROM_PRETRAINED",
            config={"num": 10, "act": "gelu_fast"},
        )

        model = DummyModel.from_pretrained(
            f"{WORKING_REPO_DIR}/{REPO_NAME}-FROM_PRETRAINED"
        )
        self.assertDictEqual(model.config, {"num": 10, "act": "gelu_fast"})

    def test_push_to_hub(self):
        model = DummyModel()
        model.push_to_hub(
            repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}-PUSH_TO_HUB",
            api_endpoint=ENDPOINT_STAGING,
            use_auth_token=self._token,
            git_user="ci",
            git_email="ci@dummy.com",
            config={"num": 7, "act": "gelu_fast"},
        )

        model_info = self._api.model_info(
            f"{USER}/{REPO_NAME}-PUSH_TO_HUB",
        )
        self.assertEqual(model_info.modelId, f"{USER}/{REPO_NAME}-PUSH_TO_HUB")

        # Clean up the remote repo created by this test.
        self._api.delete_repo(token=self._token, name=f"{REPO_NAME}-PUSH_TO_HUB")